核心思想
本文提出一種基於數據增強的小樣本學習算法(ICI)。本文的數據增強是通過自訓練(self-training)的方式實現的,具體而言就是利用有標籤的樣本先訓練得到一個分類器,然後預測無標籤樣本,得到僞標籤。選擇僞標籤中置信度較高的樣本,補充到訓練集中,實現數據擴充。通過迭代訓練的方式逐步改善分類器的效果。網絡流程如下圖所示
首先利用有標籤樣本訓練特徵提取器和線性分類器,然後無標籤的樣本經過特徵提取和簡單的線性分類後得到預測的僞標籤,利用實例置信度推斷模塊(Instance Credibility Inference,ICI)選擇出置信度較高的樣本和僞標籤,並利用其擴充支持集,而置信度較低的樣本則用於更新無標籤數據集。整個過程中最重要的一點就是如何計算預測得到的僞標籤的置信度,從而避免將分類錯誤的樣本補充到支持集中,導致數據集被污染。下面具體介紹ICI模塊的處理過程,無論對於有標籤樣本還是無標籤樣本,網絡的預測結果計算方式如下
式中表示樣本對應的特徵向量(特徵提取網絡輸出的特徵向量經過PCA降維後得到),表示分類器的係數矩陣,表示均值爲0,方差爲的高斯噪聲,用於修正實例被分配給類別的概率,的模越大,實例被分配給類別的難度越大。那麼本文的優化目標爲
式中表示懲罰項,表示懲罰項係數。爲求解上述目標,本文的損失函數如下
令可得
式中表示廣義逆矩陣。但值得注意的是,本文希望用來度量實例的置信度,而不是用,這是因爲簡單的線性分類器不足以對各種類別的樣本進行很好的分類,而且的值本身也依賴於的取值。因此我們將上式代入損失函數中得到下式
式中。令,則上式可簡化爲
利用塊下降算法可以求解上式。首先存在一個理論值,使得上式的解均爲0,該理論值如下
那麼我們可以得到由0到之間一系列的,對於每個在求解目標函數時,都能獲得一條對應的規則化路徑。而且當由0變化到時,的稀疏性不斷增強,直到他的所有元素都逐漸消失(vanish)。懲罰項會使得一個實例接一個實例的消失,且消失的越早,則表明該實例的預測結果與真實值越爲接近,因此根據消失的順序可以得到對應的置信度。
實現過程
網絡結構
無具體介紹
損失函數
見上文介紹
訓練策略
訓練和推斷過程如下
創新點
- 通過自訓練的方式獲取未標記樣本的僞標籤,並利用其擴充數據集,達到數據增強的目的
- 設計了一種基於統計學的僞標籤置信度度量方法,選擇出置信度最高的樣本,用於支持數據集的補充
算法評價
本文的整體思想並不複雜,理解難點主要集中在ICI模塊進行置信度度量的方面。本文提出的置信度度量方法是基於統計信息的,根據實驗結果來看其性能提升作用還是比較明顯的,在多個數據集上都取得了SOTA的成績。
如果大家對於深度學習與計算機視覺領域感興趣,希望獲得更多的知識分享與最新的論文解讀,歡迎關注我的個人公衆號“深視”。