Content Preserving Text Generation with Attribute Controls
4 minute read

Motivation

언어는 discrete 하고 sequential 한 특성을 지니고 있기 때문에 이미지 분야에 성공적으로 적용된 conditional generative model 이나 style transfer 기법들을 똑같이 적용할 수 없다는 어려움이 있다. 본 논문에서는 GAN 구조와 back translation 기법을 응용해 어떤 문장과 특징(attribute)이 주어졌을때 그 조건에 맞도록 주어진 문장을 수정할 수 있는 generative model 을 제안한다.

model

Differences

제시된 모델은 기존 방법론과 크게 두가지 차이점을 가지고 있다.

첫번째는 생성된 문장의 질을 평가하는 방법에서의 차별성이다. 본 논문에서는 다음의 세가지 평가척도를 사용하였다.

  1. 생성된 문장이 새 조건을 제외한 부분에서 원 문장의 의미를 보존하고 있는지 (Content preservation)
  2. 새로 주어진 조건을 잘 반영하고 있는지 (Attribute compatibility)
  3. 문법적으로 어색하지는 않은지 (Fluency)

유사한 연구를 진행했던 기존 방법들(Shen et al., 2017)은 주로 2번, attribute 에 대한 adversarial discriminator 만을 고려하였다. 본 논문에서는 이 세가지 기준이 전부 반영되도록 loss function 을 구성하였고, 셋 모두에서 기존의 방법론에 비해 좋은 결과를 보일 수 있었다.

두번째는 여러개의 특성(cf. Table 5 의 tense, voice, mood, negation)을 동시에 조작할 수 있다는 점이다. 기존의 연구는 주로 하나의 특성을 조작하는 방법이 주를 이뤘다.

그 외에 언어의 discrete 한 특성하에서 back propagation 을 돕기위해 사용되어온 soft-sampling 을 버리고 hard-sampling 을 택했다는 점도 특징이다. Training time 과 inference time 에서 나타나는 learning dynamics 의 차이를 줄이기 위해 이러한 방법을 택했다고 설명한다. ($y$ 문장을 생성하는 generator 가 gradient 를 다른 경로를 통해서도 전달받아 학습될 수 있기 때문에 soft sampling 을 버릴 수 있었던 것으로 이해했다.)

Method

모델의 구조는 크게 Generator $G$ 와 Discriminator $D$ 로 이루어져있다.

특히 Generator $G$ 는 encoder-decoder 구조를 가지고 있는데, $G$ 가 생성한 문장이 가져야하는 바람직한 특성으로는 위에서 언급되었던 Content compatibility 와 Attribute compatibility 를 들 수 있다. 따라서 $G$ 의 인코더는 문장 $x$ 로부터 content 에 해당하는 부분만 잘 뽑아낸 $z_x = G_{enc}( \cdot \vert z_x, l^\prime)$ 를 만든다. 그러고나면 디코더 $G_{dec}$ 의 input 으로 $z_x$ 와 임의의 조건 $l^\prime$ 이 들어간다 : $y \sim p_G(\cdot \vert z_x, l^\prime)$ . 여기서 생선된 문장이 $x$ 의 content 는 잘 보존하면서도 $l^\prime$ 의 조건을 잘 만족하는 것이 목표이다.

Content compatibility

$x$ 의 내용을 잘 보존하기 위해서 기존 문헌에서 사용되었던 두 종류의 reconstruction loss 가 이용하었다.

1. Autoencoding loss

$z_x = G_{enc}(x)$ 가 $x$ 의 핵심적인 내용을 담고 있다면 $x$ 의 attribute 레이블 $l$ 에 대해 $G(\cdot \vert z_x, l)$ 이 높은 확률값을 가져야한다. 이를 반영한 auto-encoding loss 를 다음과 같이 정의할 수 있다.

2. Back-translation loss

임의의 attribute 벡터 $l^\prime$ 가 존재할 때, $x$ 와 $l^\prime$ 으로 생성한 문장을 $y$ 라고 하자 : $y \sim p_G(\cdot \vert z_x, l^\prime)$ . 이때 모델이 잘 학습되었다면 $y$ 는 $x$ 의 내용(content)을 잘 보존하고 있을 것이다. 이를 활용해서 $z_y = G_{enc}(y)$ 를 만들 수 있는데, $z_y$ 가 $z_x$ 의 내용을 잘 보존하고 있다는 전제하에 $p_G(z \vert z_y, l)$ 은 높은 확률을 가질 것이다. 이렇게 back-translation loss 를 정의할 수 있다.

Interpolated reconstruction loss

위에서 언급된 두 종류의 loss 는 각기 다른 단점을 가지고 있다.

  Autoencoding loss Back translation loss
Loss function $\mathcal{L}^{ae} (x,l) = -\log{p_G}{(x \vert z_x, l)}$ $\mathcal{L}^{bt} (x,l) = -\log{p_G}{(x \vert z_y, l)}$
단점 모델이 주로 input 을 똑같이 복사하게 되어 $z$ 가 의미있는 정보를 담지 못한다. $y$가 학습 초반에는 $x$ 와 충분히 비슷하지않아서 $G$의 학습이 어려워진다.

이러한 단점을 해결하기 위해 본 논문에서는 ground truth sentence $x$ 와 generated sentence $y$ 의 latent representation 간의 interpolation 을 이용한다.

Content embedding 을 interpolation 에 의해서 구하면 디코더가 원 문장($z_x$)에 대해 가지고 있었던 의존성을 제거할 수 있다. 인코더 또한 $z_x$ 와 $z_y$ 가 유사하게 생성되도록 학습된다.

Attribute compatibility

위에서는 생성된 문장이 원문장의 content 를 보존하고 있는지에 초점을 맞췄다면 여기서는 생성된 문장이 1. 진짜 문장 같은지, 2. 주어진 attribute 에 부합하는지를 확인한다.

먼저 주어진 문장이 진짜인지, 생성된 가짜 문장인지를 판별하는 discriminator 구조를 떠올려볼 수 있다. Discriminator $D$ 의 input 으로는 $G_{dec}$ 의 hidden state sequence $h_x$ 또는 $h_y$ 와 attribute $l$ 이 사용되고, projection discriminator 의 구조를 따른다.

이를 이용해 다음과 같이 adversarial loss 를 만들 수 있다.

여기서 discriminator 가 attribute $l$ 이나 hidden state 중 한 요소만을 고려해서 real / fake 를 구분하는 것을 방지하기 위해 $(x, l^\prime)$ 이라는 fake pair 을 추가적으로 고려해 loss function 에서 얘도 fake 로 판별하도록 한다.

$\mathcal{L}^{int}$ 과 $\mathcal{L}^{adv}$ 을 종합해 최종 loss function 은 다음과 같이 정의된다. 실험 당시에는 $\lambda$ 를 $\{0.5, 1, 1.5\}$ 사이에서 튜닝해 최적의 값을 선택하였다.

model

Results

본 논문에서 제시한 모델은 Ctrl-gen (Hu et al., 2017) 과 Cross-align (Shen et al., 2017) 에 비해 attribute accuracy, content compatibility, fluency 모두에서 우월한 성능을 보였다.

model

보면 Ctrl-gen 은 생성된 문장이 content 정보를 반영하지 못하는 단점을 가지고 있고 Cross-align 은 relevant 하지만 문법적으로 어색한 문장을 만든다는 단점이 있다.

Recent Posts

Why are Sequence-to-Sequence Models So Dull?
Variational Autoregressive Decoder for Neural Response Generation
Matching Networks for One Shot Learning
Pointer Networks
Get To The Point: Summarization with Pointer-Generator Networks