深度學習論文筆記(增量學習)——End-to-End Incremental Learning

前言

我將看過的增量學習論文建了一個github庫,方便各位閱讀地址

主要工作

論文提出了一種算法,以解決增量學習中的災難性遺忘問題,與iCaRL將特徵提取器的學習與分類器分開不同,本論文提出的算法通過引入新定義的loss以及finetuning過程,在有效抵抗災難性遺忘的前提下,允許特徵提取器與分類器同時學習。

本論文提出的方法需要examplarexamplar


算法介紹


總體流程

在這裏插入圖片描述
總體分爲四個流程

  1. 構建訓練數據
  2. 模型訓練
  3. finetuning
  4. 管理 examplarexamplar

步驟一:構建訓練數據

訓練數據由新類別數據與examplar構成。

設有nn箇舊類別,mm個新類別,每個訓練數據都有兩個標籤,第ii個訓練數據的標籤爲

  1. 使用onehot編碼的1(m+n)1*(m+n)的向量pip_i
  2. 設舊模型爲Ft1F_{t-1},訓練數據爲xxqi=Ft1(x)q_i=F_{t-1}(x)qiq_i爲一個1n1*n維的向量

步驟二:模型訓練

模型可以選用常見的CNN網絡,例如ResNet32等,按照國際慣例,這一節會介紹distillation loss,作爲一篇被頂會接收的論文,自然不能免俗


loss函數介紹

符號約定

符號名 含義
NN NN個訓練數據
pip_i 含義查看上一節
qiq_i 含義查看上一節
q^i\hat q_i 新模型舊類別分支的輸出,爲一個1n1*n的向量
nn 舊類別分支
mm 新類別分支
oio_i 新模型對於第ii個數據的輸出,爲一個(n+m)1(n+m)*1的向量

Classification loss即交叉熵,如下:

LC(w)=1Ni=1Nj=1n+mpijsoftmax(oij)L_C(w)=-\frac{1}{N}\sum_{i=1}^N\sum_{j=1}^{n+m}p_{ij}*softmax(o_{ij})

其中
softmax(oij)=eoijj=1n+meoijsoftmax(o_{ij})=\frac{e^{o_{ij}}}{\sum_{j=1}^{n+m}e^{o_{ij}}}


distillation loss的形式如下

LD(w)=1Ni=1Nj=1npdistijqdistijL_D(w)=-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{n}pdist_{ij}qdist_{ij}

其中
pdistij=eq^ijtj=1neq^ijtqdistij=eqijtj=1neqijt pdist_{ij}=\frac{e^{\frac{\hat q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{\hat q_{ij}}{t}}}\\ qdist_{ij}=\frac{e^{\frac{q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{q_{ij}}{t}}}

LD(w)L_D(w)即讓模型儘可能的記住舊類別的輸出分佈。t是一個超參數,在本論文中,t=2t=2


個人疑問

distillation loss的作用是讓模型記住以往學習到的規律,相當於側面引入了舊數據集,從而抵抗類別遺忘。

直覺上來說,distillation loss應該只對舊類別數據進行計算,但是新類別數據的舊類別分支輸出仍用於計算distillation loss,論文對此給出的解釋是“To reinforce the old knowledge”

我認爲這種做法的出發點爲:舊模型對於新類別數據的輸出(經softmax處理),也是一種舊知識,也需要防止遺忘,因此,新模型對於新類別數據的舊類別輸出(經softmax處理),與舊模型對於新類別數據的輸出(經softmax處理)也要儘可能一致


步驟三:finetuning

使用herding selection算法,從新類別數據中抽取部分數據,構成與舊類別examplar大小相等的數據集,此時各類別數據之間類別平衡,利用該數據集,在小學習率下對模型進行微調,選用的loss函數應該是交叉熵。

步驟二使用類別不平衡的數據訓練模型,會導致分類器出現分類偏好,finetuning可以在一定程度上矯正分類器的分類偏好


步驟四:管理examplarexamplar

論文給出了兩類方法

  1. Fixed number of samples:沒有內存上限,每個類別都有MM個數據
  2. Fixed memory size:內存上限爲KK

使用herding selection算法選擇新類別數據,構成新類別的examplarexamplar


實驗

論文訓練模型使用了數據增強,具體方式如下:
在這裏插入圖片描述
每個實驗都進行了五次訓練,取平均準確率
實驗過程沒有太多有趣的地方,在此不做過多說明

Fixed memory size

在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述


Fixed number of samples

在CIFAR100上的結果如下
在這裏插入圖片描述
img/cls表示每個examplar中圖片的個數


Ablation studies

首先是選擇數據構建examplar的方法,論文比對了三類方法

  1. herding selection:平均準確率63.6%
  2. random selection:平均準確率63.1%
  3. histogram selection:平均準確率59.1%

上述三個選擇方法的解釋如下:
在這裏插入圖片描述
接下來論文比對了算法各部分對準確率提升的貢獻
在這裏插入圖片描述
上述模型的解釋如下
在這裏插入圖片描述

個人理解

災難性遺忘的本質是類別不平衡,模型在學習舊類別時,所使用的數據是充分的,引入知識蒸餾loss,就是儘可能保留舊數據上學習到的規律,在訓練時,相當於側面引入了舊數據。

論文在distillation loss的基礎上又引入了類別平衡條件下的finetuning,相當於進一步抵抗增量學習下類別不平衡的導致的分類器偏好問題,由此取得模型性能的提升。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章