Few-Shot Learning with Global Class Representations筆記整理
1 Introduction
在小樣本學習(Few-Shot Learning, FSL)問題上,對於base classes中的每個類別,我們往往有充足的訓練數據;對於那些novel classes中的每個類別,我們只有少量的帶標籤的數據。FSL旨在利用base calsses中大量的數據,來學習出一個可以對novel classes中的數據標籤準確辨別的模型。
注:base class和novel class是本文作者自己創造的詞彙,我沒有想到很好的翻譯方法。在文章裏,base class指擁有充足樣本的類別(用於訓練);novel class指的是那些只有少量樣本的類(用於測試)。
現在解決小樣本學習問題一般都使用元學習的方法,但是元學習的做法也有一定的侷限性,因爲它們往往只使用了源數據(source data),但是對於目標數據卻幾乎沒有使用,即使在經歷過fine-tuning階段,也無法保證能學習到滿足目標數據需求的模型。(比如,要辨別一個動物是不是貓,但現在手頭上只有5張貓的照片以及大量狗,獅子,鳥的照片。這個時候元學習的一般做法是先在狗,獅子和鳥的照片上進行訓練,訓練好後再用5張貓的照片來進行微調。)
而作者在本文提出的方法同時使用5張貓(novel classes)的照片和大量狗,獅子,鳥(base classes)的照片來進行訓練,作者把這稱爲全局表徵(global class representations)。
因爲將novel class在的少量數據和base class中的大量數據一起訓練的話,勢必會有樣本不平衡的問題,作者使用兩種方法來解決這一問題:
- 合成novel class的新樣本;
- 引入片段訓練(episodic training)。
2 Contributions
- 提出將base classes 和novel classes同時作爲全局表徵來進行小樣本學習的訓練;
3 Method
在這一節將首先介紹本方法的兩個模塊:表徵註冊模塊和樣本合成模塊。然後再介紹如何將這兩個模塊合併起來,最後介紹如何將此方法拓展到生成式FSL的設定中(generalized FSL)。其中使用表示一個樣本經過特徵抽取器F之後得到的視覺特徵。
3.1 樣本合成模塊
本模塊用於解決類別不平衡問題,共分爲兩步:第一步用原始樣本生成新的樣本,第二步用第一步獲取的所有樣本合成一個新樣本。
首先對novel classes在的原始樣本使用random cropping, random fipping和data hallucination操作(這三個方法出自論文:Low-shot learning from imaginary data.)來爲每個novel class生成個樣本。
對於一個novel class,作者先從中隨機挑選出個樣本,具體操作如下:
其中,是平均分佈。
對於一個novel class,再從平均分佈中選出個值,將這個值作爲權重,對個樣本的視覺特徵求加權和。於是可以得到一個新的樣本,具體操作如下:
通過這種方法就可以擴大類內差異了,因此少量數據的問題就得到了緩解。
3.2 註冊模塊
給定全局類表徵和視覺特徵。我們將註冊模塊簡記爲R。針對每個視覺特徵,R將會生成一個N維向量,其中第個元素是和類別的全局表徵之間的相似度分數。具體地:
其中,和分別爲樣本的視覺特徵和全局詞表徵的embeddings。
因此,我們可以設計一個註冊損失,其中爲樣本標籤,CE爲交叉熵損失。通過這個損失函數就可以使這個樣本在embedding空間儘可能的靠近這個這個類的全局類表徵。
通過將樣本與其類別的全局表徵在embedding空間進行比較,R模塊可以使得該類的全局表徵更靠近本類的樣本,更疏遠其他類的樣本。(此處可以參考聚類的思想)
大致的流程可以參考下圖:
需要注意的是,圖中的這些點都存在於embedding空間中。本方法的視覺特徵和全局類表徵都是需要經過訓練和優化的。訓練好後會使用全局類表徵向量取比對query集中的樣本,從而確定它屬於哪一類。
3.3 Few Shot Learning By Registration
是所有類的結合,包括:base classes和novel classes。有人會好奇本文一開始的全局類表徵是怎麼得到的?其實很簡單。我們可以通過對一類中所有樣本的視覺特徵取平均來獲得一個初始的全局類表徵。本文模型的目標就是爲每一個novel class學習出一個全局類表徵。
除了使用數據合成策略來緩解數據不平衡問題,作者還引入了元學習中常用的片段學習策略。簡而言之片段學習就是一次性採樣多個類中一定數量的樣本進行訓練,這些類的集合就是所謂的片段或批量(episode/mimi-batch)。
但是novel class中的樣本數量一般小於一個片段中所要求的樣本數(),比如進行5-way-5-shot實驗,support樣本數爲5,query樣本數爲15,但是novel class只有5個樣本。此時就需要使用樣本合成模塊對樣本進行擴展了。
在進行片段學習時,我們首先將採樣出的圖片輸入特徵抽取器,從而生成相應的視覺特徵。然後我們依據採樣出的support set中的數據爲每一類構造出相應的片段表徵。需要注意的是這裏的片段表徵是一種局部表徵(相對於全局表徵)。
對於base classes,我們只需對每個類中樣本的視覺特徵取平均就可得到片段表徵;對於novel classes,我們需要利用樣本合成模塊來爲每類合成一個新樣本,這個新樣本就是這個novel class的片段表徵。
然後我們將片段表徵(基於一個episode的樣本)和全局表徵(基於全部樣本)輸入註冊模塊R,從而計算出他們的相似度。類似的作者定義了一個片段表徵的損失函數:
其中表示片段表徵和所有全局表徵之間的相似度。
我們將依據這裏的相似度來選擇全局類表徵,並依據全局表徵通過最近鄰法來對query數據集中的樣本進行分類。相似地,分類損失的計算方法爲:
其中指的是選擇出的(用argmax方法進行選擇)全局表徵和query樣本之間的相似度。
將註冊損失和分類損失進行合併在一起就得到了最終的損失函數:
我們將依據這個損失函數來更新全局表徵,註冊模塊中的參數和特徵抽取器中的參數。
4 實驗和討論
模型架構
特徵抽取器:4個卷積塊,每個卷積塊包含64個的卷積層,1個batch normalization層,一個ReLU層和一個的最大池化層。
5 結論
本文提出了一種利用全局類表徵來解決小樣本學習問題的方法,同時此方法可以輕易地拓展到生成式小樣本學習中。