DeiT:使用Attention蒸餾Transformer

題目:Training data-efficient image transformers & distillation through attention

【GiantPandaCV導語】Deit是一個全Transformer的架構,沒有使用任何的卷及操作。其核心是將蒸餾方法引入VIT的訓練,引入了一種教師-學生的訓練策略,提出了token-based distillation。有趣的是,這種訓練策略使用卷積網絡作爲教師網絡進行蒸餾,能夠比使用transformer架構的網絡作爲教師取得更好的效果。

簡介

之前的ViT需要現在JFT-300M大型數據集上預訓練,然後在ImageNet-1K上訓練才能得到出色的結果,但這藉助了額外的數據。

ViT文中也表示:“do not generalize well when trained on insufficient amounts of data”數據量不足會導致ViT效果變差。

針對以上問題,Deit核心共享是使用了蒸餾策略,能夠僅使用ImageNet-1K數據集就就可以達到83.1%的Top1。

文章貢獻如下:

  • 僅使用Transformer,不引入Conv的情況下也能達到SOTA效果。

  • 提出了基於token蒸餾的策略,這種針對transformer的蒸餾方法可以超越原始的蒸餾方法。

  • Deit發現使用Convnet作爲教師網絡能夠比使用Transformer架構取得更好的效果。

知識蒸餾

Knowledge Distillation(KD)最初被Hinton提出,與Label smoothing動機類似,但是KD生成soft label的方式是通過教師網絡得到的。

KD可以視爲將教師網絡學到的信息壓縮到學生網絡中。還有一些工作“Circumventing outlier of autoaugment with knowledge distillation”則將KD視爲數據增強方法的一種。

KD能夠以soft的方式將歸納偏置傳遞給學生模型,Deit中使用Conv-Based架構作爲教師網絡,將局部性的假設通過蒸餾方式引入Transformer中,取得了不錯的效果。

本文提出了兩種KD:

  • Soft Distillation: 使用KL散度衡量教師網絡和學生網絡的輸出,即Hinton提出的方法。

\[\mathcal{L}_{\text {global }}=(1-\lambda) \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_{\mathrm{s}}\right), y\right)+\lambda \tau^{2} \mathrm{KL}\left(\psi\left(Z_{\mathrm{s}} / \tau\right), \psi\left(Z_{\mathrm{t}} / \tau\right)\right) \]

其中\(Z_s,Z_t\)分別代表學生網絡的logits輸出和教師網絡的logits輸出。

  • Hard-label Distillation: 本文提出的一個KD變體,將教師網絡得到的hard輸出作爲label,即\(y_t=argmax_cZ_t(c)\),該方法是無需調參的。

\[\mathcal{L}_{\text {global }}^{\text {hardDistill }}=\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_{s}\right), y\right)+\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_{s}\right), y_{\mathrm{t}}\right) \]

Deit蒸餾過程

在ViT架構基礎上引入了Distillation token,其地位與Class token相等,並且參與了整體信息的交互過程。

Distillation token讓模型從教師模型輸出中學習,文章發現:

  • 最初class token和distillation token區別很大,餘弦相似度爲0.06

  • 隨着class 和 distillation embedding互相傳播和學習,通過網絡逐漸變得相似,到最後一層,餘弦相似度爲0.93

實驗

Deit模型follow了Vision Transformer的設置,訓練策略有所不同,僅使用Linear classifier,而不是用MLP head。

本文提出了Deit的系列模型:

  • Deit-B:代表與ViT-B有相同架構的模型

  • Deit-B|384 : 代表對Deit-B進行finetune,分辨率提升到384

  • Deit-S/Deit-Ti:更小的模型,修改了head數量。

實驗1: 選取不同教師網絡的效果

可以發現使用RegNet作爲教師網絡可以取得更好的性能表現,Transformer可以通過蒸餾來繼承歸納偏差。

同時還可以發現,學生網絡可以取得超越老師的性能,能夠在準確率和吞吐量權衡方面做的更好。

PS:不太明白這裏對比的時候爲何不選取ViT-H(88.5%top1)作爲教師模型?

實驗2: 測試不同蒸餾方法

實驗證明:hard-label distillation能夠取得更好的結果。

實驗3: 與SOTA模型進行比較

訓練細節

  • 使用truncated normal distribution來進行初始化

  • soft蒸餾參數:\(\tau=3,\lambda=0.1\)

  • 數據增強:Autoaugment,Rand-augment,random erasing,Cutmix,Mixup,Label Smoothing等

  • 訓練300個epoch需要花費37個小時,使用兩個GPU

回顧

問: 爲什麼不同架構之間也可以蒸餾?蒸餾能夠將局部性引入transformer架構嗎?

答:教師模型能夠將歸納偏置以soft的方式傳遞給學生模型。

問: 性能增強歸功於蒸餾 or 複雜度數據增強方法?

答:蒸餾策略是有效的,但是相比ViT,Deit確實引入了非常多的數據增強方法,直接與ViT比較還是不夠公平的。Deit測試了多種數據增強方法,發現大部分數據增強方法能夠提高性能,這還是可以理解爲Transformer缺少歸納偏置,所以需要大量數據+數據增強。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章