文章目錄
前言
我將看過的增量學習論文建了一個github庫,方便各位閱讀地址
主要工作
類別不平衡導致增量學習出現災難性遺忘,本論文設計了一種loss函數,以抵抗類別不平衡造成的負面影響。
算法介紹
論文將類別不平衡對增量學習的影響分爲三個部分
- Imbalanced Magnitudes:新類別權重向量的模大於舊類別,如上節所示
- Deviation:出現災難性遺忘
- Ambiguities:新類別的權重向量與舊類別相似,模型容易將舊類別數據劃分爲新類別
爲了解決上述問題,論文通過如下三個步驟來構建最終的loss函數,以消除類別不平衡造成的影響。
- Cosine Normalization(抵抗Imbalanced Magnitudes)
- Less-Forget Constraint(抵抗Deviation)
- Inter-Class Separation(抵抗Ambiguities)
對應關係如下圖:
符號約定
符號名 | 含義 |
---|---|
特徵提取器的輸出 | |
特徵提取器的輸出L2歸一化後的結果 | |
全連接層分類器中,第類對應的1*n維權重向量 | |
全連接層分類器中,第類對應的1*n維權重向量L2歸一化後的結果 | |
全連接層分類器中,第類對應的偏置 | |
第類的概率 | |
Cosine Normalization
在CIFAR100上使用iCaRL,分類器權重的L2範式以及偏置()值可視化的結果如下圖:
從上圖至少可知,類別不平衡會導致分類器出現兩個問題
- 新類別權重向量的L2範式大於舊類別權重向量
- 新類別的偏置(參數)基本大於0,舊類別的偏置(參數)基本小於0
上述兩個問題可能導致分類器出現分類偏好
個人疑問
實驗一:在Large Scale Incremental Learning一文中,去除掉分類器的偏置項(參數)後,分類器的準確率有所上升
實驗二:去除上述兩個影響後,分類器的準確率有所提升(請查看Ablation Study部分)。
上述兩個實驗,都是給出準確率,但是抵抗分類偏好,不應該給出混淆矩陣嗎?
回答
一個簡單的步驟,例如去除偏置項、L2歸一化只是在一定程度上抵抗分類偏好,其混淆矩陣仍可能顯示分類器有分類偏好。採取某些步驟後,模型的準確率大幅上升,意味着誤分爲新類別的數據被分類器正確分類,在一定程度上說明該步驟可以抵抗分類偏好
爲了解決上述兩個問題,論文做了兩個工作
- 對每個類別的權重向量使用L2歸一化,這樣所有類別的權重向量的L2範式均爲1
- 去除偏置
如果將特徵提取器的輸出也進行L2歸一化,經過softmax層處理後的結果如下:
是一個可學習參數,其存在對於分類而言意義不大(所有值都放大或是縮小相同倍數,大小關係不變),論文對其解釋是用來控制softmax分佈的峯度,可能與優化有關,個人認爲這個參數沒有深入瞭解的必要,因此不在此做過多解釋
爲什麼要對特徵提取器的輸出進行L2歸一化呢?
此時特徵提取器的輸出向量與類權重向量都位於一個高維球體內部,但論文並沒有解釋這樣做有什麼好處,由於特徵提取器進行L2正則化有助於模型收斂,這裏這麼做可能是爲了加速模型收斂
Less-Forget Constraint
按國際慣例,一篇增量學習論文必然會對loss函數進行魔改,本論文自然不能免俗
論文凍結了全連接層分類器舊類別分支的權重向量,定義的知識蒸餾loss如下:
與表示增量學習前後特徵提取器L2歸一化後的輸出,由於進行了L2歸一化,與的模爲1,當上式取值爲0時,意味着兩個向量的夾角爲0,則有,由於全連接層舊類別分支的權重向量被凍結,此時對於舊類別數據,增量學習前後模型的輸出一致(新類別分支的輸出會爲0)。
作者認爲全連接層分類器中的權重在一定程度上反映了類與類之間的關係,因此一個
自然的想法就是固定舊類別分支的權重向量(從而保留類與類之間的關係),讓訓練後的特徵提取器儘可能與訓練前的一致,從而抵抗災難性遺忘。
Inter-Class Separation
爲了預防模型將新舊類別混淆,論文定義瞭如下loss函數:
選出新類別中,輸出()值與舊類別輸出值最接近的個分支,計算其差距,只要差距大於,損失函數的值即爲0,對於舊類別數據,隨着優化的進行,舊類別分支的輸出與新類別分支的輸出差距會逐漸拉大,從而防止將舊類別數據劃分爲新類別數據
需注意,舊類別的權重向量是固定的,上式中,是固定的
損失函數
即爲交叉熵損失函數,表示訓練數據,表示訓練數據中的舊類別數據,是是一個自適應參數,其取值爲
表示舊類別與新類別的數目,是一個自定義大小的參數
疑問
由於每次需要學習的新類別數目是固定的,即固定,不斷提高,會導致下降,即distillation loss在損失函數中的佔比下降,這有點奇怪,隨着增量學習步驟的增多,distillation loss在損失函數中的佔比應該增加纔對。
實驗
baseline | 解釋 |
---|---|
iCaRL-CNN | 用examplar+distillation loss訓練CNN |
iCaRL-NME | 用examplar+distillation loss訓練CNN,分類器採用nearest- mean-of-exemplars(最近鄰) |
Ours-CNN | examplar+上述損失函數訓練CNN |
Ours-NME | examplar+上述損失函數訓練CNN,分類器採用nearest- mean-of-exemplars(最近鄰) |
joint-CNN | 用全部數據訓練CNN |
CIFAR100、ImageNet-Subset、ImageNet-Full上的結果
比較有意思的是Ours-CNN與Ours-NME差距不大,兩者只是採用的分類器不同,NME並不會出現分類偏好的情況,這在一定程度上說明,使用論文提出的損失函數進行增量學習,可以讓分類器抵抗分類偏好
按國際慣例,應該給出混淆矩陣進一步說明抵抗分類偏好,如下
Ablation Study
符號約定
- CN:Cosine Normalization
- LS:Less-Forget Constraint
- IS:Inter-Class Separation
- AW:自適應參數,即式1
每進行完一次增量學習,都會使用類別平衡的數據(examplar+新類別部分數據)對模型進行finetuning(這個操作可以查看End-to-End Incremental Learning)
CN、LS、IS的影響
上圖可以看出損失函數每個部分對於準確率提升的效果,說明三者缺一不可,上圖中的Ours-CNN使用了AW,未使用CBF,其他模型都使用了CBF,可以看出,CBF對於模型的準確率的影響不大,說明應用本論文提出的方法,分類器分類偏好已經被較好解決
AW的影響
所有實驗數據都是進行多次實驗取平均
個人理解
爲什麼增量學習的CNN比非增量學習的CNN準確率低?
答案是災難性遺忘,但是造成災難性遺忘的核心原因,個人覺得還是類別不平衡,類別不平衡會導致分類器出現分類偏好(更偏向於新類別,因爲新類別的訓練數據多),因此,目前閱讀過的大部分論文都是針對分類器入手。
想要提高增量學習的分類準確率,首要解決的是類別不平衡問題帶來的負面影響,
但是即使類別不平衡問題可以較好的解決,模型的分類準確率爲什麼無法達到非增量學習分類模型的準確率呢?