前言
在大規模小樣本學習(large-scale FSL)中,有這樣一個baseline:使用所有的源類(source class)訓練一個feature embedding模型,然後用這個模型提取出目標類(target class)中樣本的特徵,以進行最近鄰分類。從下圖可以看出,僅使用簡單的最近鄰(NN)方法得到的結果,甚至能與目前最先進的FSL模型相匹配:
這就說明了一個問題:在SGM、PPA和LSD這些FSL模型中的知識遷移基本上都是通過feature embedding模型提取的可轉移特徵(transferable feature)實現的。也就是說,只有當這些可轉移特徵能夠更好的表示目標類的樣本時,大規模小樣本學習的性能纔會變得更好。
由此,本文提出了一種新的FSL模型,如下圖所示,通過讓源類和目標類共享一個類層次結構(class hierarchy)來學習更具有可轉移性的feature embedding模型。主要思想是:將源類和目標類之間的語義關係作爲先驗知識,進而學習feature embedding模型以實現對目標類樣本的識別。
這種語義關係被編碼爲類似於樹的類層次結構,這樣的一個樹能夠覆蓋所有現存的目標類別。源類和目標類在樹中都是作爲葉子節點,在語義上相似的類被分組,然後每個簇在上一層形成一個父節點(即超類節點)。
爲了利用類層次結構中的先驗知識,本文還提出了層次預測網絡(hierarchy prediction network),它可以將類層次結構編碼到分類過程中:
- 在訓練時,源類的樣本被送入一個CNN中,後跟層次預測網絡。由於源類和目標類肯定會存在一些共同的超類,因此這個層次預測網絡就能學習到一些可轉移特徵用於FSL;
- 在測試時,從目標類中採樣兩種樣本:few-shot樣本和test樣本,需要從這兩種樣本中提取出視覺特徵,然後使用few-shot樣本的視覺特徵作爲參考,使用最近鄰方法來識別test樣本。
模型設計
首先來定義大規模FSL問題。設表示源類集,表示目標類集,這兩個類集是不重疊的,即。然後從中採樣訓練集,從中採樣few-shot樣本集和測試集。中每個類都有充足的帶標記的樣本,而中每個類只有少量帶標記的樣本。那麼大規模FSL的目標就是在上獲得好的分類結果。
本文的大規模FSL方法包括兩個階段:
- 學習可轉移的視覺特徵;
- 識別中的樣本標籤。
1. 特徵學習
本文提出了一種可轉移特徵學習模型,在這個模型中,首先構建樹狀類層次結構,以編碼源類和目標類之間的語義關係;然後利用層次預測網絡將類層次結構中的先驗知識整合起來,從而爲大規模FSL學習可轉移的視覺特徵。使用和語義關係來訓練這個模型。
類層次結構
首先來說一下類層次結構的這個樹是怎麼構建的,
- 首先,將每個源類和目標類的類名錶示爲詞向量(word vector),它們都作爲樹的葉子節點,構成類層次結構中最底部的類層;
- 從葉子節點開始,將lower layer中的詞向量進行聚類,得到upper layer中的節點,將每個聚類中心作爲upper layer中的父節點(即超類節點),該父節點的詞向量是通過對其子類的詞向量做平均得到的。同一層的超類節點構成了一個超類層。
這樣就可以得到類層次結構的一個樹,共包含個超類層和一個類層,設表示類層,表示個超類層。如下圖所示:
層次預測網絡
在獲得類層次結構之後,接下來就是如何利用它了,也就是本文提出的模型,這個模型是通過使用層次預測網絡擴展CNN得到的,如下圖所示:
層次預測網絡(圖中的紫色框)預測的是超類的標籤,預測過程可分爲兩步:
- 首先在不同的類/超類層上預測標籤。在這一步中,在CNN的頂上添加個不共享的全連接(FC)網絡,其中包括softmax層,給定一個目標樣本,每個FC網絡在相應的層上預測類/超類的概率分佈;
- 然後將類/超類層的層次結構編碼爲超類標籤預測。也就是說,通過將第一步中獲得的當前層和較低層的預測結果組合起來,來推斷每層的超類標籤。在這一步中,通過使用個不共享的FC網絡來進行編碼,每個FC網絡推斷出相應層的超類標籤。具體來說就是,對於最底部的超類層(層)對應的FC網絡來說,將第一步中後兩層(和)的輸出作爲該FC的輸入,FC的輸出就是的最終預測結果:
其中表示第一步中的FC的輸出,也就是類層的預測結果;表示第一步中的FC的輸出,也就是最底部的超類層的預測結果。是第二步中的FC操作;最後的輸出表示上對所有可能超類標籤的預測分佈。
那麼,也可以得到出層的FC輸出的預測結果:
和分別表示在第一步和第二步中的FC網絡,是第一步的輸出,是第二步的輸出。
給定輸入圖像,損失函數爲:
2. 標籤推理
一旦通過訓練得到了特徵學習模型,那麼接下來就可以利用這個模型提取和中圖像樣本的特徵。利用這些視覺特徵,就可以直接使用最近鄰方法推斷中樣本的標籤。具體來說就是,計算中每類樣本的視覺特徵的平均,然後給定中的某個樣本,計算它與每個類平均的餘弦距離,距離最小的類標籤就作爲這個test樣本的標籤。