Matching Networks for One Shot Learning論文解讀

這篇文章在元學習領域筆記重要,之前一直想讀,這次正好有機會就把它給刷了。

本篇論文屬於小樣本學習領域,但是本篇論文中的Matching Networks常被用於與Meta-learning任務中的方法進行比較。這篇論文出自Google DeepMind團隊,發表於2016年。

1 Motivation

人類可以可以通過非常少量的樣本學習到一個新的概念,比如一個小孩子看完一張長頸鹿的照片之後就認識了長頸鹿這個動物。但是最好的深度學習模型依然需要成百上千的例子來學習到一個新的概念。因此本文就考慮如何通過一個樣本就讓深度學習模型學會一個新概念。

傳統上訓練出一個模型需要使用很多樣本進行很多次的參數更新,因此作者認爲可以使用一個無參數的模型。參考KNN這種度量式的做法,作者將有參數的模型和無參數的模型進行了結合。

2 Contribution

  1. 在模型層面上,作者提出了一個Matching Networks, 將注意力機制和記憶機制引入快速學習任務中。
  2. 在訓練流程上,作者訓練模型時遵循了一個很簡單的規則,即測試和訓練條件必須匹配。作者在訓練時僅用每個類別中很少的樣本進行訓練,因爲在測試時也使用的是很少的樣本。(即訓練條件和測試條件匹配)

3 Method

3.1 Model Architecture

在這裏插入圖片描述
gθg_{\theta}fθf_\theta分別是對訓練數據和測試數據的編碼函數。Matching Networks可以簡潔表示爲計算一個無標籤樣本的標籤爲y^\hat{y}的概率,這個計算方法跟KNN很像,相當於是加權後的KNN:
P(y^x^,S)=i=1ka(x^,xi)yi P(\hat{y}|\hat{x},S) = \sum^{k}_{i=1}a(\hat{x},x_i)y_i
其中xi,yix_i,y_i是輸入的支撐集(support set)中的樣本S={(xi,yi)}i=1kS = \{(x_i,y_i)\}^k_{i=1}aa類似於注意力機制中的核函數,用來度量x^xi\hat{x},x_i的匹配度。
a(x^,xi)=ec(f(x^),g(xi))j=1kec(f(x^),g(xj)) a(\hat{x},x_i) = \frac{e^{c(f(\hat{x}),g(x_i))}}{\sum^k_{j=1}e^{c(f(\hat{x}),g(x_j))}}
在這裏公式ff定義了對測試樣本的編碼方式,對於Figure 1 中的gθg_{\theta};公式gg定義了對訓練樣本的編碼方式,對應於Figure 1 中的fθf_\theta。這個公式先對f(x^),g(xi)f(\hat{x}),g(x_i)計算了一個餘弦距離,然後在做一個softmax歸一化。

3.2 Training Function g

gg是一個BiLSTM,它的輸入是xix_i和支撐集SS
g(xi,S)=hi+hi+g(xi) g(x_i, S) = \overrightarrow{h_i} + \overleftarrow{h_i}+ g'(x_i)

其中$g'(x_i)$是一個神經網絡,比如VGG或者Inception。

3.3 Test Function f

ff是一個迭代了K步的 LSTM,它的輸出是LSTM最後輸出的隱狀態hh。即f(x^,S)=hkf(\hat{x},S)=h_k,其中hkh_k由(3)式決定:

在這裏插入圖片描述
其中,ff'是一個embedding函數,比如一個CNN。

3.4 Training procedure

給定一個有k個樣本的支撐集S={(xi,yi)}i=1kS = \{(x_i,y_i)\}^k_{i=1},對測試樣本 x^\hat{x}分類爲 CS(x^)C_S(\hat{x})。定義SCS(x^)S \rightarrow C_S(\hat{x}) 這一映射爲P(y^x^,S)=i=1ka(x^,xi)yiP(\hat{y}|\hat{x},S) = \sum^{k}_{i=1}a(\hat{x},x_i)y_i

在測試過程中,給定一個新的支撐集SS',我們可以用之前學到的模型對每個測試樣本x^\hat{x}得到他們可能的label y^\hat{y}

4 Experimental results

4.1 Omniglot dataset

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-pzZgjWu5-1583568374737)(C:%5CUsers%5C14186%5CAppData%5CRoaming%5CTypora%5Ctypora-user-images%5Cimage-20200306191821654.png)]
Omniglot 數據集包含來自 50個不同國家的字母表的 1623 個不同手寫字符。每一個字符都是由 20個不同的人通過亞馬遜的 Mechanical Turk 在線繪製的。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-JivqVe9V-1583568113112)(C:%5CUsers%5C14186%5CAppData%5CRoaming%5CTypora%5Ctypora-user-images%5Cimage-20200306191322810.png)]

4.2 ImageNet dataset

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-t3tZ3n9r-1583568113112)(C:%5CUsers%5C14186%5CAppData%5CRoaming%5CTypora%5Ctypora-user-images%5Cimage-20200306191849333.png)]
作者一共在ImageNet數據集上做了三組實驗:

  • In the rand setup:在訓練集中隨機去除了118個label的樣本,並將這118個標籤的樣本用於之後的測試。
  • For the dogs setup:移除了所有屬於狗這一大類的樣本(一共118個子類),之後用這118個狗的子類樣本做測試.
  • 作者還新定義了一個數據集 miniImageNet —— 一共有100個類別,每個類有600個樣本。其中80個類用於訓練20個類用於測試。

實驗結果爲:
在這裏插入圖片描述

4.3 Penn Treebank dataset

這個Pennn Treebank數據集來自華爾街日報。作者利用數據集做了一個one-shot Language Model的實驗,利用上下文來預測中間詞。作者通過query集中與support集中兩個句子的比較來確定中間詞。如下圖所示。
在這裏插入圖片描述
但是實驗結果並不理想。

The LSTM language model oracle achieved an upper bound of 72.8% accuracy on the test set. Matching Networks with a simple encoding model achieve 32.4%, 36.1%, 38.2% accuracy on the task with k = 1, 2, 3 examples in the set, respectively. Future work should explore combining parametric models such as an LSTM-LM with non-parametric components such as the Matching Networks explored here.

作者只是在這裏提供了一個Matching Network用於語言模型的思路。

5 Conclusions

作者在本文中引入了Matching Networks並在小樣本學習任務上取得了很不錯的效果。作者還在ImageNert數據集上定義了一個one-shot任務,此後ImageNet數據集成爲了Meta-Learning的標準數據集。同時作者啓發性地將one-shot任務應用於語言模型,爲後續研究提供了一個很好的思路。

References

[1] 知乎文章
[2] github 文章
[3] 論文原文

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章