Matching Networks for One Shot Learning
2 minute read

Motivation

딥러닝 학습은 여전히 많은 데이터를 요구하며 stochastic gradient descent 을 통해 weight 을 여러번 업데이트해야하기 때문에 시간이 오래걸린다.

Differences

본 논문에서는 딥러닝 학습이 오래걸리는 이유로 모델의 모수적인 성질 (parametric aspect) 을 꼽는다. 이 때문에 모델의 파라미터가 데이터를 천천히 학습해나간다는 것이다. 이러한 점에 착안해 논문에서는 존재하는 비모수적 방법론을 응용해 적은 양의 데이터로도 빠르게 학습할 수 있는 방법을 제시하고자한다. 비모수적 모델은 새로운 데이터 포인트를 빠르게 학습을 하면서도 기존의 학습 내용이 극심하게 사라지지는 않는다는 강점을 가진다. 따라서 본 논문에서는 모수적 방법론의 장점인 Common example 에 대한 높은 일반화 성능을 유지하면서도, 비모수적 방법론을 이용해 새로운 데이터에 대해서는 빠른 습득을 할 수 있는 모델을 제안하고자 한다.

Method

논문에서는 다음의 두가지 방법론을 새롭게 제안한다.

  1. 모델 구조
  2. 학습 전략

1. Model Architecture

Matching Network 는 크게 보면 external memory 를 추가한 뉴럴넷 구조에 속한다. 이러한 구조에는 attention mechanism 을 적용한 seq2seq, memory network, pointer network 가 속한다. 이 모델들은 당장 풀어야하는 task 에 유용한 정보를 담은 memory matrix 을 활용할 수 있는 neural attention mechanism 을 지닌다는 공통점을 가진다.

Matching Network 는 학습이 완료되고 나면 파라미터를 수정하지 않고도 test label 을 예측할 수 있다. 구체적으로는 $k$ 개의 (이미지, 레이블) 데이터를 가지고 있는 support set $\mathcal{S} = \{(x_i, y_i)\}_{i=1}^k$ 를 이용해 $P(\hat{y} \vert \hat{x}, \mathcal{S})$ 을 예측한다.

결국 kNN 과 유사하게 support set 에 있는 $x_i$ 의 label 을 이용해 예측을 하는데, 각 $x_i$ 와 $\hat{x}$ 와의 유사도에 따라 대응되는 label $y_i$ 가 가중치를 받는 구조이다.

Attention Kernel

그렇다면 $x_i$ 와 $\hat{x}$ 의 유사도를 어떻게 구할것인가의 문제가 남는다. $a(\hat{x}, x_i)$ 를 정의해줘야한다는 점에서 Matching Network 는 metric learning 과 관계가 있다. 여기서는 단순한 식을 사용해 $a(\hat{x}, x_i)$ 를 정의한다.

Full context embeddings

예측이 전체 support set $\mathcal{S}$ 에 기반해 이뤄지긴 하지만 $f$, $g$ 을 사용해 임베딩을 생성하는 당시에는 $\mathcal{S}$ 가 고려되지 않는다. 이 문제를 해결하기 위해 $g(x_i)$, $f(\hat{x})$ 를 $g(x_i, \mathcal{S})$, $f(\hat{x}, \mathcal{S})$ 로 수정하는 방법이 사용되었다.

이 방법을 사용하면 먼저 support set 내에 $x_i$ 와 상당히 유사한 $x_j$ 가 존재할때, $g$ 는 이 정보를 고려하여 $x_i$ 를 임베딩할 수 있다. 또 $\hat{x}$ 를 임베딩하는 $f$ 도 $\mathcal{S}$ 을 고려하도록 다음과 같이 LSTM 구조를 사용하였다.

2. Training Strategy

Training episode

  1. $\mathcal{L}$ 을 $T$ 로부터 샘플링한다. (e.g. $\mathcal{L} = \{\text{cats}, \text{dogs}\}$)
  2. $\mathcal{L}$ 을 이용해 support set $\mathcal{S}$ 와 batch $\mathcal{B}$ 를 샘플링한다.
  3. Support set $\mathcal{S}$ 를 이용해 $\mathcal{B}$ 의 샘플에 대해 레이블을 예측한다.

이 순서를 따르고 다음의 objective function 을 최대화하는 방향으로 파라미터를 학습한다.

Results

Comments

Recent Posts

Why are Sequence-to-Sequence Models So Dull?
Variational Autoregressive Decoder for Neural Response Generation
Content Preserving Text Generation with Attribute Controls
Pointer Networks
Get To The Point: Summarization with Pointer-Generator Networks