【知識蒸餾】Deep Mutual Learning

【GiantPandaCV導語】Deep Mutual Learning是Knowledge Distillation的外延,經過測試(代碼來自Knowledge-Distillation-Zoo), Deep Mutual Learning性能確實超出了原始KD很多,所以本文分析這篇CVPR2018年被接受的論文。同時PPOCRv2中也提到了DML,並提出了CML,取得效果顯著。

引言

首先感謝:https://github.com/AberHu/Knowledge-Distillation-Zoo

筆者在這個基礎上進行測試,測試了在CIFAR10數據集上的結果。

學生網絡resnet20:92.29% 教師網絡resnet110:94.31%

這裏只展示幾個感興趣的算法結果帶來的收益:

  • logits(mimic learning via regressing logits): + 0.78

  • ST(soft target): + 0.16

  • OFD(overhaul of feature distillation): +0.45

  • AT(attention transfer): +0.71

  • NST(neural selective transfer): +0.38

  • RKD(relational knowledge distillation): +0.65

  • AFD(attention feature distillation): +0.18

  • DML(deep mutual learning): + 2.24 (ps: 這裏教師網絡已經訓練好了,與DML不同)

DML也是傳統知識蒸餾的擴展,其目標也是將大型模型壓縮爲小的模型。但是不同於傳統知識蒸餾的單向蒸餾(教師→學生),DML認爲可以讓學生互相學習(雙向蒸餾),在整個訓練的過程中互相學習,通過這種方式可以提升模型的性能。

DML通過實驗證明在沒有先驗強大的教師網絡的情況下,僅通過學生網絡之間的互相學習也可以超過傳統的KD。

如果傳統的知識蒸餾是由教師網絡指導學生網絡,那麼DML就是讓兩個學生互幫互助,互相學習。

DML

小型的網絡通常有與大網絡相同的表示能力,但是訓練起來比大網絡更加困難。那麼先訓練一個大型的網絡,然後通過使用模型剪枝、知識蒸餾等方法就可以讓小型模型的性能提升,甚至超過大型模型。

以知識蒸餾爲例,通常需要先訓練一個大而寬的教師網絡,然後讓小的學生網絡來模仿教師網絡。通過這種方式相比直接從hard label學習,可以降低學習的難度,這樣學生網絡甚至可以比教師網絡更強。

Deep Mutual Learning則是讓兩個小的學生網絡同時學習,對於每個單獨的網絡來說,會有針對hard label的分類損失函數,還有模仿另外的學生網絡的損失函數,用於對齊學生網絡的類別後驗。

這種方式一般會產生這樣的疑問,兩個隨機初始化的學生網絡最初階段性能都很差的情況,這樣相互模仿可能會導致性能更差,或者性能停滯不前(the blind lead the blind)。

文章中這樣進行解釋:

  • 每個學生主要是倍傳統的有監督學習損失函數影響,這意味着學生網絡的性能大體會是增長趨勢,這意味着他們的表現通常會提高,他們不能作爲一個羣體任意地漂移到羣體思維。(原文: they cannot drift arbitrarily into groupthink as a cohort.)

  • 在監督信號下,所有的網絡都會朝着預測正確label的方向發展,但是不同的網絡在初始化值不同,他們會學到不同的表徵,因此他們對下一類最有可能的概率的估計是不同的。

  • 在Mutual Learning中,學生羣體可以有效彙集下一個最後可能的類別估計,爲每個訓練實例找到最有可能的類別,同時根據他們互學習對象增加每個學生的後驗熵,有助於網絡收斂到更平坦的極小值,從而帶來更好的泛華能力和魯棒性。

  • Why Deep Nets Generalise 有關網絡泛化性能的討論認爲:在深度神經網絡中,有很多解法(參數組合)可以使得訓練錯誤爲0,其中一些在比較loss landscape平坦處參數可以比其他narrow位置的泛華性能更好,所以小的干擾不會徹底改變預測的效果;

  • DML通過實驗發現:(1)訓練過程損失可以接近於0 。(2)在擾動下對loss的變動接受能力更強。(3)給出的class置信度不會過於高。總體來說就是:DML並沒有幫助我們找到更好的訓練損失最小值,而是幫助我們找到更廣泛/更穩健的最小值,更好地對測試數據進行泛華。

DML具有的特點是:

  • 適合於各種網絡架構,由大小網絡混合組成的異構的網絡也可以進行相互學習(因爲只學習logits)

  • 效能會隨着隊列中網絡數量的增加而增加,即互學習對象增多的時候,性能會有一定的提升。

  • 有利於半監督學習,因爲其在標記和未標記數據上都激活了模仿損失。

  • 雖然DML的重點是得到某一個有效的網絡,整個隊列中的網絡可以作爲模型集成的對象進行集成。

DML中使用到了KL Divergence衡量兩者之間的差距:

\[D_{K L}\left(\boldsymbol{p}_{2} \| \boldsymbol{p}_{1}\right)=\sum_{i=1}^{N} \sum_{m=1}^{M} p_{2}^{m}\left(\boldsymbol{x}_{i}\right) \log \frac{p_{2}^{m}\left(\boldsymbol{x}_{i}\right)}{p_{1}^{m}\left(\boldsymbol{x}_{i}\right)} \]

P1和P2代表兩者的邏輯層輸出,那麼對於每個網絡來說,他們需要學習的損失函數爲:

\[\begin{aligned} &L_{\Theta_{1}}=L_{C_{1}}+D_{K L}\left(\boldsymbol{p}_{2} \| \boldsymbol{p}_{1}\right) \\ &L_{\Theta_{2}}=L_{C_{2}}+D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right) \end{aligned} \]

其中\(L_{C_{1}},L_{C_{2}}\)代表傳統的分類損失函數,比如交叉熵損失函數。

可以發現KL divergence是非對稱的,那麼對兩個網絡來說,學習到的會有所不同,所以可以使用堆成的Jensen-Shannon Divergence Loss作爲替代:

\[\frac{1}{2}\left(D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right)+D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right)\right) \]

更新過程的僞代碼:

更多的互學習對象

給定K個互學習網絡,\(\Theta_{1}, \Theta_{2}, \ldots, \Theta_{K}(K \geq 2)\), 那麼目標函數變爲:

\[L_{\Theta_{k}}=L_{C_{k}}+\frac{1}{K-1} \sum_{l=1, l \neq k}^{K} D_{K L}\left(\boldsymbol{p}_{l} \| \boldsymbol{p}_{k}\right) \]

將模仿信息變爲其他互學習網絡的KL divergence的均值。

擴展到半監督學習

在訓練半監督的時候,我們對於有標籤數據只使用交叉熵損失函數,對於所有訓練數據(包括有標籤和無標籤)的計算KL Divergence 損失。

這是因爲KL Divergence loss的計算天然的不需要真實標籤,因此有助於半監督的學習。

實驗結果

幾個網絡的參數情況:

在CIFAR10和CIFAR100上訓練效果

在Reid數據集Market-1501上也進行了測試:

發現互學習目標越多,性能呈上升趨勢:

結論

本文提出了一種簡單而普遍適用的方法來提高深度神經網絡的性能,方法是在一個隊列中通過對等和相互蒸餾進行訓練。

通過這種方法,可以獲得緊湊的網絡,其性能優於那些從強大但靜態的教師中提煉出來的網絡。
DML的一個應用是獲得緊湊、快速和有效的網絡。文章還表明,這種方法也有希望提高大型強大網絡的性能,並且以這種方式訓練的網絡隊列可以作爲一個集成來進一步提高性能。

參考

https://github.com/AberHu/Knowledge-Distillation-Zoo

https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章