核心思想
該文章採用一種帶有嵌入特徵提取器的最近鄰方法實現了小樣本或單樣本分類任務,其集合了參量化方法和非參量化方法的優勢。普通的最近鄰算法,直接計算測試樣本與各個類別的訓練樣本之間的距離,然後選擇距離最近的K個樣本,按照加權求和或投票法來決定測試樣本所屬的類別(kNN),該方法無需訓練只需要保存所有訓練樣本的特徵信息,屬於非參量化學習算法。普通的神經網絡算法,需要通過訓練更新網絡中的各個權重參數,最終利用softmax函數計算訓練樣本屬於各個類別的概率,屬於參量化學習算法。而對於小樣本或單樣本分類任務,由於訓練樣本極少無法充分的優化網絡權重參數,因此效果通常較差。
本文創新性的提出Matching Network算法結合了兩種算法的優勢,首先從訓練集中選出個類別(例如k=5),再從每個類別中選擇若干個帶有標籤的樣本(例如1或5),這5個或25個樣本構成了支持集(Support Set),除此之外還要在之前的類別中選擇一個帶有標籤的Batch樣本用於訓練網絡。然後利用神經網絡作爲特徵提取器,提取支持集中樣本的特徵信息;再用特徵提取器(通常與相同,採用VGG等網絡結構),提取Batch樣本的特徵信息。再利用餘弦距離函數度量與各個類別的支持樣本對應的之間的距離,並利用softmax函數將距離轉化爲屬於各個類別的概率。最後採用加權求和的方式得到Batch樣本的預測類別。作者稱可以理解爲一種注意力機制,當採用核函數的形式時(如上文采用的softmax函數),該算法相當於核密度估計算法(KDE);當採用0/1函數時,該算法相當遠k最近鄰算法(kNN)。在測試時,需要提供支持集S和測試樣本,此時S中樣本的類別可以是訓練樣本之外的類別,但測試樣本的類別必須包含在此時的S之內。利用訓練好的特徵提取器和分別提取特徵向量,並計算測試樣本與支持集S中各個樣本之間的距離(相似性)作爲權重,最終採用加權求和的方式得到預測類別。這類似於拿着測試樣本與支持集中的各個樣本進行匹配,尋找最相近的那個樣本,這也是取名爲Matching Network的原因。
在此基礎上,作者提出在提取特徵時如果孤立地考慮單獨一個樣本是短視的,應該綜合考慮整個支持集中的樣本,因此得到的特徵提取器不再是和,而是和,作者稱之爲FCE(Full Context Embeddings)。作者引入了當時最新的長短期記憶網絡LSTM,將支持集S中的各個樣本看做一個序列,輸入到網絡中得到
其中和就是由普通的神經網絡構成的特徵提取器,表示LSTM的訓練次數。而對於則是採用了雙向LSTM算法來計算。作者指出這種方式能夠使網絡有選擇的記住或遺忘一部分信息,並且增加了計算注意力權重的“深度”。
實現過程
網絡結構
如上文所述,特徵提取器可採用常見的VGG或Inception網絡,作者設計了一種簡單的四級網絡結構用於圖像分類任務的特徵提取,每級網絡由一個64通道的3 * 3卷積層,一個批規範化層,一個ReLU激活層和一個2 * 2的最大池化層構成。然後將最後一層輸出的特徵輸入到LSTM網絡中得到最終的特徵映射和。
損失函數
論文中並沒有明確的說明具體的損失函數是什麼,只是提到訓練的目標是最小化基於支持集S的Batch樣本B的預測誤差,可以表示爲對數最大似然的形式
訓練策略
訓練過程在上文已經介紹過,簡單來講就是每次迭代包含多個任務Task,每個任務中包含一個支持集S和一個Batch樣本B,每個支持集中包含多個類別的樣本,其中有且只有一種與樣本B同類。作者還提到測試條件與訓練條件必須匹配,這裏並不是很明確,Protypical Network中認爲是指支持集中的類別數量和樣本數量應該保持相同。此外作者指出訓練樣本和測試樣本可以選擇不同的類別,比如在不包含狗的數據集上訓練得到的網絡仍然可以用於狗這種類別的分類任務,這也是非參量化算法所具備的優勢,網絡可以很容易的遷移到其他任務中。當然,必須說明的是如果訓練樣本的類別和測試樣本的類別差距很大時,該算法也無法起效了。
推廣應用
作者不僅將該算法應用於圖像分類任務,還推廣到語言任務中用於填補句子中缺失的一個單詞,這展示了該算法強大的遷移能力,可以把特徵提取器替換成其他的形式,以滿足新的任務需求。
創新點
- 創新性的採用匹配的形式實現小樣本分類任務,引入最近鄰算法的思想解決了深度學習算法在小樣本的條件下無法充分優化參數而導致的過擬合問題,且利用帶有注意力機制和記憶模塊的神經網絡解決了普通最近鄰算法過度依賴度量函數的問題,將樣本的特徵信息映射到更高維度更抽象的特徵空間中。
- 新型的訓練策略,一個訓練任務中包含支持集和Batch樣本
算法評價
該算法提出了一種全新的小樣本分類任務的解決方案,因此得到了廣泛的關注。其分類效果較好,遷移能力強大(體現在兩個方面:1.在某個訓練集上得到的網絡可以用於其他類別圖像分類;2. 通過替換特徵提取網絡可以應用到不同的分類任務中)。但同時他也存在一些問題,如受到非參量化算法的限制,隨着支持集S的增長,每次迭代的計算量也會隨之快速增長,導致計算速度降低。此外,在測試時必須提供包含目標樣本類別在內的支持集,否則他只能從支持集所包含的類別中選擇最爲接近的一個輸出其類別,而不能輸出正確的類別。