知識蒸餾(Knowledge Distillation)

mark一下,感謝作者分享!
https://blog.csdn.net/nature553863/article/details/80568658

1、Distilling the Knowledge in a Neural Network

Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知識蒸餾(暗知識提取)的概念,通過引入與教師網絡(teacher network:複雜、但推理性能優越)相關的軟目標(soft-target)作爲total loss的一部分,以誘導學生網絡(student network:精簡、低複雜度)的訓練,實現知識遷移(knowledge transfer)。

如上圖所示,教師網絡(左側)的預測輸出除以溫度參數(Temperature)之後、再做softmax變換,可以獲得軟化的概率分佈(軟目標),數值介於0~1之間,取值分佈較爲緩和。Temperature數值越大,分佈越緩和;而Temperature數值減小,容易放大錯誤分類的概率,引入不必要的噪聲。針對較困難的分類或檢測任務,Temperature通常取1,確保教師網絡中正確預測的貢獻。硬目標則是樣本的真實標註,可以用one-hot矢量表示。total loss設計爲軟目標與硬目標所對應的交叉熵的加權平均(表示爲KD loss與CE loss),其中軟目標交叉熵的加權係數越大,表明遷移誘導越依賴教師網絡的貢獻,這對訓練初期階段是很有必要的,有助於讓學生網絡更輕鬆的鑑別簡單樣本,但訓練後期需要適當減小軟目標的比重,讓真實標註幫助鑑別困難樣本。另外,教師網絡的推理性能通常要優於學生網絡,而模型容量則無具體限制,且教師網絡推理精度越高,越有利於學生網絡的學習。

教師網絡與學生網絡也可以聯合訓練,此時教師網絡的暗知識及學習方式都會影響學生網絡的學習,具體如下(式中三項分別爲教師網絡softmax輸出的交叉熵loss、學生網絡softmax輸出的交叉熵loss、以及教師網絡數值輸出與學生網絡softmax輸出的交叉熵loss):

聯合訓練的Paper地址:https://arxiv.org/abs/1711.05852

2、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions

這篇文章將total loss重新定義如下:

GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch

total loss的Pytorch代碼如下,引入了精簡網絡輸出與教師網絡輸出的KL散度,並在誘導訓練期間,先將teacher network的預測輸出緩存到CPU內存中,可以減輕GPU顯存的overhead:

  1. def loss_fn_kd(outputs, labels, teacher_outputs, params):
  2. """
  3. Compute the knowledge-distillation (KD) loss given outputs, labels.
  4. "Hyperparameters": temperature and alpha
  5. NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
  6. and student expects the input tensor to be log probabilities! See Issue #2
  7. """
  8. alpha = params.alpha
  9. T = params.temperature
  10. KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
  11. F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
  12.                F.cross_entropy(outputs, labels) * (1. - alpha)
  13. return KD_loss

3、Ensemble of Multiple Teachers

第一種算法:多個教師網絡輸出的soft label按加權組合,構成統一的soft label,然後指導學生網絡的訓練:

第二種算法:由於加權平均方式會弱化、平滑多個教師網絡的預測結果,因此可以隨機選擇某個教師網絡的soft label作爲guidance:

第三種算法:同樣地,爲避免加權平均帶來的平滑效果,首先採用教師網絡輸出的soft label重新標註樣本、增廣數據、再用於模型訓練,該方法能夠讓模型學會從更多視角觀察同一樣本數據的不同功能:

Paper地址:

https://www.researchgate.net/publication/319185356_Efficient_Knowledge_Distillation_from_an_Ensemble_of_Teachers

4、Hint-based Knowledge Transfer

爲了能夠誘導訓練更深、更纖細的學生網絡(deeper and thinner FitNet),需要考慮教師網絡中間層的Feature Maps(作爲Hint),用來指導學生網絡中相應的Guided layer。此時需要引入L2 loss指導訓練過程,該loss計算爲教師網絡Hint layer與學生網絡Guided layer輸出Feature Maps之間的差別,若二者輸出的Feature Maps形狀不一致,Guided layer需要通過一個額外的迴歸層,具體如下:

具體訓練過程分兩個階段完成:第一個階段利用Hint-based loss誘導學生網絡達到一個合適的初始化狀態(只更新W_Guided與W_r);第二個階段利用教師網絡的soft label指導整個學生網絡的訓練(即知識蒸餾),且total loss中soft target相關部分所佔比重逐漸降低,從而讓學生網絡能夠全面辨別簡單樣本與困難樣本(教師網絡能夠有效辨別簡單樣本,而困難樣本則需要藉助真實標註,即hard target):

Paper地址:https://arxiv.org/abs/1412.6550

GitHub地址:https://github.com/adri-romsor/FitNets

5、Attention to Attention Transfer

通過網絡中間層的attention map,完成teacher network與student network之間的知識遷移。考慮給定的tensor A,基於activation的attention map可以定義爲如下三種之一:

隨着網絡層次的加深,關鍵區域的attention-level也隨之提高。文章最後採用了第二種形式的attention map,取p=2,並且activation-based attention map的知識遷移效果優於gradient-based attention map,loss定義及遷移過程如下:

Paper地址:https://arxiv.org/abs/1612.03928

GitHub地址:https://github.com/szagoruyko/attention-transfer

6、Flow of the Solution Procedure

暗知識亦可表示爲訓練的求解過程(FSP: Flow of the Solution Procedure),教師網絡或學生網絡的FSP矩陣定義如下(Gram形式的矩陣):

訓練的第一階段:最小化教師網絡FSP矩陣與學生網絡FSP矩陣之間的L2 Loss,初始化學生網絡的可訓練參數:

訓練的第二階段:在目標任務的數據集上fine-tune學生網絡。從而達到知識遷移、快速收斂、以及遷移學習的目的。

Paper地址:

http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

7、Knowledge Distillation with Adversarial Samples Supporting Decision Boundary

從分類的決策邊界角度分析,知識遷移過程亦可理解爲教師網絡誘導學生網絡有效鑑別決策邊界的過程,鑑別能力越強意味着模型的泛化能力越好:

文章首先利用對抗攻擊策略(adversarial attacking)將基準類樣本(base class sample)轉爲目標類樣本、且位於決策邊界附近(BSS: boundary supporting sample),進而利用對抗生成的樣本誘導學生網絡的訓練,可有效提升學生網絡對決策邊界的鑑別能力。文章採用迭代方式生成對抗樣本,需要沿loss function(基準類得分與目標類得分之差)的梯度負方向調整樣本,直到滿足停止條件爲止:

loss function:

沿loss function的梯度負方向調整樣本:

停止條件(只要滿足三者之一):

結合對抗生成的樣本,利用教師網絡訓練學生網絡所需的total loss包含CE loss、KD loss以及boundary supporting loss(BS loss):

Paper地址:https://arxiv.org/abs/1805.05532

8、Label Refinery:Improving ImageNet Classification through Label Progression

這篇文章通過迭代式的誘導訓練,主要解決訓練期間樣本的crop與label不一致的問題,以增強label的質量,從而進一步增強模型的泛化能力:

誘導過程中,total loss表示爲本次迭代(t>1)網絡的預測輸出(概率分佈)與上一次迭代輸出(Label Refinery:類似於教師網絡的角色)的KL散度:

文章實驗部分表明,不僅可以用訓練網絡作爲Label Refinery Network,也可以用其他高質量網絡(如Resnet50)作爲Label Refinery Network。並在誘導過程中,能夠對抗生成樣本,實現數據增強。

GitHub地址:https://github.com/hessamb/label-refinery

9、Miscellaneous

-------- 知識蒸餾可以與量化結合使用,考慮了中間層Feature Maps之間的關係,可參考:

https://blog.csdn.net/nature553863/article/details/82147933

-------- 知識蒸餾與Hint Learning相結合,可以訓練精簡的Faster-RCNN,可參考:

https://blog.csdn.net/nature553863/article/details/82463249

-------- 模型壓縮方面,更爲詳細的討論,請參考:

https://blog.csdn.net/nature553863/article/details/81083955

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章