核心思想
本文提出一種基於遷移學習的半監督小樣本學習算法(TransMatch)。整個算法並不複雜,首先利用帶有標籤的基礎數據集訓練特徵提取網絡,然後用該特徵提取網絡爲新的數據集初始化分類器權重,最後用半監督學習的方式進一步更新整個網絡。整個流程如下圖所示
第一階段:預訓練階段。這一階段沒有什麼值得介紹的,就是用帶有標籤的基礎數據集對特徵提取器進行訓練。
第二階段:分類器權重“生成”階段(Classifier Weight Imprinting)。在這一階段,使用已經預訓練好的特徵提取網絡,對新的帶有標籤的數據集進行特徵提取,並生成對應的分類器權重。本文采用一種叫做Weight Imprinting的方法來生成分類器的權重,方法如下
式中表示類別對應的分類器權重,表示特徵提取網絡,表示類別中第個樣本。通過上式得到每個類別對應的權重後,再通過計算餘弦距離的方式進行分類
式中對於樣本,分別計算其與個類別權重之間的餘弦相似度,並選擇餘弦相似度最高的哪一類作爲預測結果。
第三階段:半監督微調訓練階段。在這一階段採用新的帶有標籤的數據集和與類別相同但不帶有標籤的數據集,共同對網絡進行微調訓練。本文采用MixMatch的方式進行半監督訓練,定義表示個帶有標籤的樣本,表示個不帶有標籤的樣本。首先對每個無標籤的樣本進行數據擴充(應該採用的是常規的翻轉,放縮等形式)得到個合成樣本,然後用第二階段訓練得到的分類器對每個無標籤樣本進行預測,並取個合成樣本的平均值作爲預測結果
銳化操作(sharpen operation)用於進一步增強預測結果
其中,這樣就得到了無標籤樣本對應的標籤信息了。將數據集級聯後,再將順序打亂,得到新的混合數據集,然後將其分爲以下兩個集合
其中混合操作MixUP計算過程如下
式中,是從Beta分佈中隨機生成的。
實現過程
網絡結構
特徵提取網絡採用寬闊的殘差網絡WRN-28-10。
損失函數
損失函數計算過程如下
其中
訓練策略
本文的訓練過程如下
創新點
- 採用基於遷移學習的半監督訓練方法實現小樣本學習任務
- 採用Weight Imprinting的方式進行分類器權重生成,採用MixUp方式進行半監督訓練
算法評價
與之前研究較多的採用元學習的小樣本學習方法不同,本文沿用了更爲傳統的遷移學習思想,並結合半監督學習方式,證明了遷移學習還是能夠在小樣本場景下取得較好的效果的。但本文核心創新點並不多,有一種拼湊的感覺。無論是Weight Imprinting分類器權重生成還是MixUp半監督訓練方法都是借鑑了別人的方案。
如果大家對於深度學習與計算機視覺領域感興趣,希望獲得更多的知識分享與最新的論文解讀,歡迎關注我的個人公衆號“深視”。