淺談“知識蒸餾”技術在機器學習領域的應用

什麼是知識蒸餾技術?

知識蒸餾技術首次出現是在Hinton幾年前的一篇論文《Distilling the Knowledge in a Neural Network》。老大爺這麼大歲數了還孜孜不倦的發明各種人工智能領域新名詞,讓我這種小白有很多可以去學習瞭解的內容,給個贊。

那什麼是知識蒸餾技術呢?知識蒸餾技術的前提是將模型看作一個黑盒,數據進入後經過處理得到輸出。通常意義上,複雜的模型的輸出會比簡單模型準確,那麼是否有辦法讓複雜模型的知識傳遞給簡單模型,就是知識蒸餾要探索的內容。

這就有點類似於遷移學習的原理,在遷移學習中,網絡先學習大數量級的數據,然後生成base模型,再用小數據在base模型的基礎上Fine-Tune。在知識蒸餾中,也是先生成複雜的Teacher模型,然後採用Teacher模型將知識傳遞給簡單的Student模型的方式。

屏幕快照 2020-03-05 下午8.42.41.png

這樣做的好處就是Student網絡不必那麼複雜,某種意義上實現了模型壓縮的功能。

爲什麼叫蒸餾呢?

我最先好奇的其實不是Teacher網絡和Student網絡怎麼傳遞知識,而是爲什麼用了Distilling這個詞,我甚至覺得是不是某些人翻譯錯誤了。於是有道了一下,蒸餾也可以是提煉的意思,我就懂了。

屏幕快照 2020-03-05 下午8.46.48.png

 

在化學領域有一個概念叫沸點,不同液體有不同的沸點,假設酒精和水混合在一起,我們想提取混合物中的水,就可以將溫度加熱到小於水的沸點而大於酒精沸點的溫度,這樣酒精就揮發了。知識蒸餾也是用相似的手段將需要的知識從Teacher網絡蒸餾出來傳遞給Student網絡。

Teacher網絡和Student網絡

具體怎麼做呢?就是先構建一個非常複雜的網絡作爲Teacher網絡,默認它的模型預測準確性很高。然後再構建一個簡單的Student網絡,用Teacher網絡的輸出結果q和Student網絡的輸出結果p做Cross Entropy(交叉熵),y是真實的目標值,最終算Loss的公式如下。

這樣就達到了知識傳遞的問題,但是第二個問題來了。如果Teacher網絡的預測準確率很準,比如Teacher網絡是一個圖片識別模型,識別貓、狗、兔子,Teacher網絡很準的話,最後的輸出可能是以下這樣的概率分佈結果,非常不均勻

  1. 貓的概率:0.998

  2. 兔子的概率:0.0013

  3. 狗的概率:0.0007

這種結果被稱爲Hard Label,因爲真實傳遞下去的知識只有“貓”這一個結果,忽略了“兔子”比“狗”更像“貓”這樣的知識,因爲“兔子”和“狗”的權重太低。Teacher網絡需要Soft Label,怎麼做呢?在Softmax結果加入下面的公式:

其中zi是Softmax輸出的logit,T取1,那麼這個公式就是輸出Hard Label。如果T的取值大於1,T越大整個Label的分佈就變得越均勻,Hard Label就自然轉變成了Soft Label。

總結

上一篇文章講了模型壓縮技術中的剪枝、量化、共享權重,加上今天這篇知識蒸餾就比較完整了。感覺知識蒸餾這種方案比較適合終端設備的模型壓縮,特別是CV相關的模型。

非技術背景,純YY,有不對的請大神指正。

參考文章:

(1)https://zhuanlan.zhihu.com/p/90049906

(2)https://blog.csdn.net/nature553863/article/details/80568658

(3)https://zhuanlan.zhihu.com/p/81467832

(衷心感謝以上文章的作者們)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章