Motivation
ClovaAI今年ICCV做了還幾篇總結性的工作,該篇也類似,先總結當下做feature distillation的各個方向,總體的pipeline是選取student和teacher網絡結構的某個位置然後對特徵進行變換,最後拉進他們的距離
- Teacher transform: 爲了讓teacher和student的feature map一樣大(空間或者通道),會對teacher的feature上做reduce dim,但是這樣會損失一部分信息,比如Attention Transfer是去除通道維度,在各個位置求和
- Student transform: 爲了讓teacher和student的feature map一樣大(空間或者通道),會對student的feature進行變換,比如FitNets就加了1x1卷積,這樣不會使student的feature信息產生丟失
- Distillation feature position: 在網絡的哪些地方做feature distillation,Attention Transfer那篇文章在每組的最後一個輸出都做,而FitNets僅在最後一組的最後輸出做。
- Distance function: 最常用的就是L1、L2聚集
Method
-
Teacher transform: 爲了不丟失信息,採用margin ReLU,正值都保留,負值被抑制。這樣的話就不用學習精確的“沒有用”的負值,而集中精力學習“有用”的正值。而Heo提出的AB正值也不用學習精確,就捨棄太多了
,m是各個通道負值的期望,在訓練中動態計算
-
Student transform: 採用和HitNets一樣,在student後加1x1卷積
-
Distillation feature position: 在ReLu前,爲了保留正值和負值
-
Distance function
我們的蒸餾是在ReLu前做的,如果teacher小於0時,student比它小就不必懲罰,因爲經過ReLU後是一樣的
5. 總結
6. 其他細節
關於bn的問題,在進行KD時,teacher的bn應該是training mode,而且爲了和teacher一致,我們在做student transform時1x1卷積後加了bn
Experiments
CIFAR100,22、13、7三種方法是feature distillation + output distillation,其他都是默認的無output distillation
ImageNet,student ResNet50超過了ResNet152的性能
Object Detection、Semantic segmentation都有提升
Analysis
分析蒸餾到底有木有拉進teacher和student的距離
從下表看到,僅在訓練早期進行蒸餾的方法如FitNets、AB,student和teacher的KL距離反而變大了。
訓練過程中持續蒸餾的方法如KD、AT等確實拉進了KL距離,而我們的方法則把KL距離拉得更近,效果也最好
baseline是L2 loss用在每個階段的最後一層輸出,改變位置到ReLU前提升最多