【Distilling】《Learning Efficient Object Detection Models with Knowledge Distillation》

在這裏插入圖片描述

NIPS-2017



1 Background and Motivation

伴隨着CNN 的發展,object detection 的精度有了極大的提升,然而要落地應用,實時性仍是個考驗!

知識蒸餾,Knowledge Distillation:a shallow or compressed model trained
to mimic the behavior of a deeper or more complex model can recover some or all of the accuracy drop.(更多應用可參考下面三篇文章)

雖然能在保留精度的同時,能壓縮模型,提升速度,但只在分類任務上得到了印證,在更復雜的 object detection 上還有待探索!

想要把 Knowledge Distillation 應用到 object detection 上,相比於 classification 任務,有如下問題和挑戰:

  • suffers more degradation with compression,因爲目標檢測任務的標籤更 expensive,usually less voluminous.(感覺應該是說,目標檢測任務標籤信息量更大,根據標籤學到的模型更爲複雜,壓縮後損失更多!就像高清壁紙和普通壁紙,同樣的壓縮比,高清壁會更模糊,這裏的高清壁紙就可以理解爲更 expensive)
  • 分類任務中,each class is equally important,然後目標檢測任務中,background class is far more prevalent
  • 目標檢測任務更爲複雜,both classification and bounding box regression
  • 在這裏插入圖片描述

最後一個挑戰沒看懂(可能沒有看用知識蒸餾遷移不同領域任務的論文,get 不到作者的點)

2 Advantages / Contributions

  • the first successful demonstration 目標檢測中應用知識蒸餾壓縮模型
  • 提出了新的 loss 來處理上述的問題和挑戰
  • perform comprehensive empirical evaluation using multiple large-scale public benchmarks(在多個公共數據集上做了大量的驗證實驗)
  • 對泛化性問題和欠擬合問題給出了自己的 insights

3 Method

作者採用的 Faster R-CNN framework,從主幹網絡,RPN,RCN(頭部)三個部分,進行了知識蒸餾!

在這裏插入圖片描述

  • 主幹網絡:adaptation layers for hint learning
  • 分類任務:weighted cross entropy loss for severe category imbalance issue
  • 迴歸任務:teacher bounded regression loss,teacher’s regression output as a form of upper bound,學生網絡迴歸的更優則無損失

公式如下:

在這裏插入圖片描述

  • NNMM 分別是對應部分的 batch-size 大小,λ\lambdaγ\gamma 是超參數,作者這裏分別設定爲 1 和 0.5
  • LclsL_{cls},分類損失包括 hard target 和知識蒸餾中的 soft target
  • LregL_{reg},迴歸損失包括 smooth L1 和新提出的 teacher bounded L2 regression loss
  • LHintL_{Hint},是主幹的損失

下面來詳細看看各個部分的公式細節

3.1 Knowledge Distillation for Classification with Imbalanced Classes

老師網絡的預測結果 PtP_t 可以表示如下
在這裏插入圖片描述
其中,ZtZ_t 是老師網絡的 logits,TT 是溫度(細節介紹可以參考 【Distilling】《Distilling the Knowledge in a Neural Network》(arXiv-2015, In NIPS Deep Learning Workshop, 2014)

同樣,學生網絡的輸出 PsP_s 也可以表示成如下形式
在這裏插入圖片描述

在知識蒸餾方法中,學生網絡的優化損失如下:
在這裏插入圖片描述

  • LhradL_{hrad} 就是用 gt 監督的 cross entropy
  • LsoftL_{soft} 就是利用了老師網絡的信息的 soft loss,好處是 The soft labels contain information about the relationship between different classes as discovered by teacher
  • μ\mu 是超參數,來 balance hard and soft loss

在分類任務中,分類錯誤只會來自於 foreground categories,而目標檢測任務中的分類子任務,background and foreground 的錯誤 can dominate the error,foreground 的誤分概率比較低,作者通過增大背景類的權重來處理這個問題,形式如下,
在這裏插入圖片描述
多加了一個 wcw_cw0=1.5w_0 = 1.5 for the background class and wi=1w_i = 1 for all the others(注意這裏 PtPt 不是 ground truth 的 one-hot 編碼,所有交叉熵中加權重是有效的)

作者討論下溫度 TT 的問題,我們知道 TT 越大,會縮小各類概率分佈的差距,參考 【Distilling】《Distilling the Knowledge in a Neural Network》(arXiv-2015, In NIPS Deep Learning Workshop, 2014)。這在小任務中(such as classification on small datasets like MNIST)非常適用!
缺點是,也會增大噪聲的分佈,不利於學習,不適用於大任務,例如 classification on larger datasets 或者 object detection!

作者實驗中把 TT 設置爲了 1

3.2 Knowledge Distillation for Regression with Teacher Bounds

  • regression 不像 classification task,它是 unbounded,
  • In addition, the teacher may provide regression direction that is contradictory to the ground truth direction.

基於以上兩點,我們不能直接學 teacher network 的迴歸值(第二點,嗯……)!而把損失設計成如下形式:
在這裏插入圖片描述

  • mm is a margin,vv 權重,作者設置爲了 0.5
  • yregy_{reg} denotes the regression ground truth label,就是 proposal 和 gt 之間的迴歸量
  • RtR_tRsR_s 是 teacher 和 student 網絡學出來的迴歸量
  • LsL1L_{sL1} 就是普通的 smooth L1 迴歸 loss

如果學生網絡學出來的沒有老師網絡好,纔會有懲罰!也就是達到老師的要求就不強求了! LbL_b 不侷限於是 L2 Loss 的形式,L1 或者 smooth L1 都行!!!

3.3 Hint Learning with Feature Adaptation

在這裏插入圖片描述
上述論文中證明,using the intermediate representation of the teacher as hint can help the training process and improve the final performance of the student.
在這裏插入圖片描述
LHintL_{Hint} 是學主幹的監督信息損失,形式如下
在這裏插入圖片描述
在這裏插入圖片描述

V,ZV, Z分別是來自老師和學生網路的 feature vectors,必須 h,w,channels 相同,有時候需要加 adaption layer(full connection 或者 111*1 convolution) 來使得 ZZVV 維度一模一樣!

作者發現,即使 V,ZV, Z dimension 一樣,加了 adaption layer 效果會更好,adaption layer 也可以用在不同的模型之間,例如 VGG16 and AlexNet !

4 Experiments

4.1 Datasets

  • KITTI
  • PASCAL VOC 2007
  • MS COCO
  • ImageNet DET benchmark (ILSVRC 2014)

4.2 Overall Performance

在這裏插入圖片描述

Teacher 列中 - 表示,teacher 和 student 是同網絡,可以看出,大網絡作爲 teacher 能帶來更多的提升

在這裏插入圖片描述
這個表是蒸餾分辨率(downsampling the input size quadratically reduces convolutional resources and speeds up computation.)

老師網絡 688,學生網絡 344!精度相當,速度接近 x2

4.3 Speed-Accuracy Trade off in Compressed Models

在這裏插入圖片描述

4.4 Ablation Study

在這裏插入圖片描述
VGG16 爲 Teacher,Tucker 爲 student,在 PASCAL 和 KITTI 上評估!對比了不同的蒸餾方式,可以看出,作者設計的 teacher bounded regression loss 比設計成 regression loss 形式好, weighted cross entropy losscross entropy loss 好,adaptation layers for hint learning 比沒有 adaptation layers 的好!

在這裏插入圖片描述

  • Distillation improves generalization
    (‘Car’ shares more common visual characteristics with ‘Truck’ than with ’Person’)

  • Hint helps both learning and generalization
    目標檢測任務還路漫漫,It seems the learning algorithm is suffering from the saddle point problem. the hint may provide an effective guidance to avoid the problem by directly having a guidance at an intermediate layer.(避免陷入局部最優解???)

5 Conclusion(own)

  • In object detection, however, failing to discriminate between background and foreground can dominate the error, while the frequency of having misclassification between foreground categories is relatively rare.
  • 可以蒸餾分辨率
  • Hint learning + knowledge distilling
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章