深度學習論文筆記(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation

這篇文章非常有意思,本文文字部分較多,主要記錄了個人對於文章的一些思考

前言

《On the Efficacy of Knowledge Distillation》於2019年發表在ICCV上。通過實驗,作者發現了一個“怪相”,準確率越高的模型並不一定就是好的teacher模型,對於同一個student模型而言,teacher模型越大,teacher模型的準確率越高,知識蒸餾得到的student模型性能卻越差。作者認爲是student模型與teacher模型的容量相差太大,導致student模型無法模擬teacher模型,從而出現上述怪相。爲了解決上述“怪相”,作者提出在訓練teacher模型時進行early stop,即teacher模型訓練一定輪數後停止訓練,此時teacher模型可能還未較好擬合訓練集。

按我的理解,知識蒸餾的實質是想讓student模型模擬teacher模型的輸出,從而在student模型的參數空間中找到一個解,這個解是teacher模型找到的解的近似解,當student模型與teacher模型容量差距太大時,將導致student模型無法找到近似解,但是由此又引發一個問題,該如何防止student模型與teacher模型的容量差距太大呢?論文並沒有解釋。作者從另外一個角度提出瞭解決方案,通過在訓練時early stop teacher模型,讓teacher模型找到的解儘可能簡單,從而讓student模型儘可能找到近似解。

2015年,hinton在論文《Distilling the Knowledge in a Neural Network》中提出了知識蒸餾,本文中的“知識蒸餾”均指該方法。


疑問:高準確率的大模型一定就是好teacher嗎?

在這裏插入圖片描述
上圖展示了當student模型爲resnet18,teacher模型爲resnet18、34、50時,在IMageNet數據上,利用知識蒸餾得到的student模型的準確率,第一行表示不使用知識蒸餾訓練的resnet18的準確率,從上圖可以看出兩個問題:

  • teacher模型越大、準確率越高,student模型的準確率卻越低。
  • 模型結構相同,使用了知識蒸餾的resnet18準確率不如未使用知識蒸餾的resnet18

問題二

第二個問題爲本人提出,問題二駁斥了目前網上對於知識蒸餾爲什麼有效的一個解釋,即soft label相比於hard label,可以提供更多類與類之間相似性的信息,這類信息將有助於student模型區分類。但是從上圖數據來看,當teacher模型爲resnet18時,給出的soft label也可以反映類與類之間的相似性,但是student模型的準確率卻並沒有更高,因此個人不是很認同這個觀點。

針對於問題二,個人的理解是——若student模型與teacher模型的結構不同,teacher模型性能優於student模型,此時知識蒸餾可能可以讓student模型在參數空間中找到一個與teacher模型解近似的解,這個近似解通常不如teacher模型的解,但可能可以讓student模型與teacher模型性能近似,從而提高student模型的性能。若student模型與teacher模型的結構相同,teacher模型的性能與student模型性能基本一致,此時student找到的近似解並不一定就能提高student模型的準確率。


問題一

針對問題一,作者提出了三種假設

  • teacher模型越大,給出的soft label越接近於hard label,給出的信息越來越近似於hard label
  • student可以模擬teacher,但這不能導致student的泛化性能提升,即知識蒸餾是無效的
  • student無法模擬teacher

假設一
高溫可以防止teacher模型的soft label與hard label近似,防止soft label給出的信息近似於hard label,但是當溫度爲20時,在ImageNet數據集上依然會出現“怪相”,如下圖:
在這裏插入圖片描述
因此作者否定了這個假設


假設二、假設三
在這裏插入圖片描述
當student模型爲ResNet18,teacher模型分別爲ResNet18、34、50時,在ImageNet上運用知識蒸餾的結果如上,KD(Train)表示訓練集上的KD loss值,CE(Train)表示訓練集上交叉熵loss的值,有(ES KD)符號標記的數據可以暫時不看。

如果student模型可以模擬teacher模型,那麼在訓練集上的KD loss應該趨近於0,依據上圖(無ES KD部分),我們可以得知student模型無法模擬teacher模型(KD loss大於1),並且模型越大,KD loss也越大,這說明student模型越來越難以模擬teacher模型。由此推翻了假設二,印證了假設三

可以看到ResNet18(無ES KD)一行的KD loss值非常大,這並不能說明問題,由於student模型和teacher模型共享參數空間,teacher模型找到的解存在於student模型的參數空間中,這個解可以使KD loss取值接近於0,這裏KD loss這麼大,很可能歸因於優化算法不夠智能,或是初始化參數不一致,導致無法找到teacher模型的解。


可能的解決方案

上一節我們通過實驗證明了假設三,具體而言,即teacher模型越大,student模型越難在參數空間中找到一個不錯的近似解(體現在KD loss會隨着teacher模型容量增大而增大),導致student模型性能越來越糟糕。

依據上述假設,作者給出了三個可能的解決方案

  • 初期使用交叉熵+KD loss作爲損失函數,訓練一段時間後,只使用交叉熵損失函數
  • 使用Sequential knowledge distillation,即選擇一個容量位於student模型與teacher模型之間的middle模型,先將teacher的知識蒸餾到middle模型,在將middle模型的知識蒸餾到student模型
  • 對teacher模型使用early stop,即teacher模型訓練一定epoch後停止訓練,接着進行蒸餾

解決方案一
由於student模型難以找到近似解,那就是用知識蒸餾做一個pretrain,接着用交叉熵損失函數,以求找到一個儘可能好的解,實驗結果爲下圖(含有ES KD符號)
在這裏插入圖片描述
可以看到,ES KD的性能優於知識蒸餾,但是仍然會出現“怪相”。


解決方案二
在這裏插入圖片描述
看上圖第一、三行最後一列數據,基本沒有差別,這裏其實有一個核心的問題,要選擇怎樣的middle模型,才能即讓middle模型找到teacher模型解近似的解,又讓student模型(small模型)找到與middle模型解近似的解,這似乎把問題變得更加複雜,解決方案二是不能work的。


解決方案三

通過early stop teacher模型的訓練,讓teacher模型找到的解儘可能簡單,從而方便student模型找到對應的近似解。

在CIFAR10數據集上使用上述策略,結果如下:
在這裏插入圖片描述

x軸表示teacher模型訓練的epoch數目,可以看到,在一定範圍內,當teacher模型越大,student模型的錯誤率越小,比如第一幅圖中的WRN-4、6、8,當epoch太小時,此時teacher模型找到的解可能很糟糕,這導致student模型性能較差,當epoch太大時,teacher模型找到的解太複雜,student模型難以找到近似解,導致student模型性能較差。


如果您想了解更多有關深度學習、機器學習基礎知識,或是java開發、大數據相關的知識,歡迎關注我們的公衆號,我將在公衆號上不定期更新深度學習、機器學習相關的基礎知識,分享深度學習中有趣文章的閱讀筆記。

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章