DeiT:訓練ImageNet僅用4卡不到3天的平民ViT | ICML 2021

論文基於改進訓練配置以及一種新穎的蒸餾方式,提出了僅用ImageNet就能訓練出來的Transformer網絡DeiT。在蒸餾學習時,DeiT以卷積網絡作爲teacher,能夠結合當前主流的數據增強和訓練策略來進一步提高性能。從實驗結果來看,效果很不錯

來源:曉飛的算法工程筆記 公衆號

論文: Training data-efficient image transformers & distillation through attention

Introduction


  Vision Transformer一般要先在大型計算設施上預訓練數以億計的圖片纔能有較好的性能,這極大地提高其應用門檻。爲此,論文基於ViT提出了可在ImageNet上訓練的Vision Transformer模型DeiT,僅需要一臺電腦(4卡)訓練不到三天(53小時的預訓練和可選的20小時微調)的時間。在沒有外部數據預訓練的情況下,在ImageNet上達到了83.1% 的最高精度。

  此外,論文還提出了一種針對Transformer的蒸餾策略,通過一個蒸餾token確保student網絡通過注意力從teacher網絡那裏進行學習。當使用卷積網絡作爲teacher網絡時,ImageNet上可達到85.2%的準確性。

  總體而言,論文主要有以下貢獻:

  • 通過實驗表明,在沒有外部數據的情況下,Vision Transformer也可以在ImageNet上達到SOTA的結果,而且僅需要4卡設備訓練三天。
  • 論文提出了一種基於蒸餾token的新蒸餾方法,這種用於Transformer的蒸餾方法大幅優於一般蒸餾方法。蒸餾token與class token的作用相同,都參與注意力計算中,只是蒸餾token的訓練目的在於復現teacher網絡的標籤預測。
  • 有趣的是,論文發現在使用新蒸餾方法時,用卷積網絡作爲teacher要比用另一個相同準確率的transformer的作爲teacher的效果要好。
  • 在Imagenet上預訓練的模型可以轉移到不同的下游任務(如細粒度分類),得到很不錯的性能。

Distillation through attention


Soft distillation

  一般的蒸餾方法都是Soft distillation,其核心目標是最小化teacher網絡和student網絡的softmax輸出之間的Kullback-Leibler散度。

  定義\(Z_t\)爲teacher網絡的logits輸出(輸入softmax的向量),\(Z_s\)爲student網絡的logits輸出。用\(\tau\)表示蒸餾溫度,\(\lambda\)表示平衡Kullback-Leibler散度損失(KL)和交叉熵損失(LCE)的權值,\(\psi\)表示softmax函數。定義soft distillation的目標函數爲:

Hard-label distillation

  論文提出了一種蒸餾的變體,將teacher網絡的預測標籤作爲蒸餾的GT標籤。假設\(y_t = argmax_c Z_t(c)\)是teacher網絡的預測標籤,與之相關的hard-label distillation目標爲:

  對於同一張圖片,teacher網絡預測的標籤可能隨着特定的數據增強而有所變化。從實驗結果來看,將預測標籤作爲蒸餾目標的做法比傳統的做法更好,不僅無額外參數,概念上還更簡單:teacher網絡預測的\(y_t\)與真實標籤\(y\)是相同的作用。
  此外,hard label也可以通過label smoothing轉換爲軟標籤,其中GT標籤具有\(1 - \varepsilon\)的概率,其餘類共享\(\varepsilon\)概率。在相關的實驗中,參數固定爲\(\varepsilon = 0.1\)

Distillation token

  論文提出的蒸餾方案如如圖2所示,在輸入的token序列中添加一個蒸餾token。蒸餾token與class token類似,通過self-attention與其它token交互並將最後一層中的對應輸出作爲網絡輸出,其訓練目標爲損失函數中的蒸餾損失部分。蒸餾token使得模型可以像常規蒸餾一樣從teacher網絡的輸出中學習,同時與class token保持互補的關係。

  論文發現,訓練後的輸入層class token和蒸餾token收斂到了完全不同的向量,平均餘弦相似度僅爲0.06。但隨着在網絡的計算,class和蒸餾token在越深層中的對應輸出逐漸變得更加相似,最後一層達到了較高的相似度(cos=0.93),但沒有完全相同。這是符合預期的,因爲兩個token的目標就是產生相似但不相同的目標。

  論文也嘗試替代實驗,用另一個class token代替teacher網絡的蒸餾token進行僞蒸餾學習。但無論如何隨機且獨立地初始化兩個class token,訓練後都會收斂到相同的向量(cos=0.999),其對應的輸出也是準相同的。這表明這個代替的class token不會對分類性能帶來任何影響,相比之下蒸餾token則能帶來顯著的提升。

Fine-tuning with distillation

  在分辨率增加的fine-tuning階段,同樣使用真實標籤和teacher網絡預測標籤進行訓練。此時需要一個具有相同目標分辨率的teacher網絡,可通過FixRes的做法從之前的低分辨率teacher網絡中轉換。論文也嘗試了只用真實標籤進行fine-tuning,但這導致了性能的降低。

Classification with our approach:joint classifiers

  在測試時,網絡輸出的class token和蒸餾token都用於標籤分類。論文的建議做法是將這兩個token獨立預測後再融合,即將兩個分類器的softmax輸出相加再進行預測。

Transformer models


  DeiT的架構設計與ViT相同,唯一的區別是訓練策略和蒸餾token,訓練策略的區別如表9所示。此外,在預訓練時不使用MLP,僅使用線性分類器。

  爲避免混淆,用ViT來指代先前工作中的結果,用DeiT來指代論文的結果。如果未指定,DeiT指的是DeiT-B,與ViT-B具有相同的架構。當以更大的分辨率fine-tune DeiT時,論文會在名字的最後附加分辨率,例如DeiT-B↑384。最後,當使用論文提出的蒸餾方法時,論文會用一個蒸餾符號將其標識爲DeiT⚗.。

  如表1所示,DeiT-B的結構與ViT-B完全一樣,參數固定爲\(D = 768\)\(h = 12\)\(d = D/h = 64\)。另外,論文設計了兩個較小的模型:DeiT-S和DeiT-Ti,減少了head的數量,\(d\)保持不變。

Experiment


  不同類型的teacher網絡的蒸餾效果。

  不同蒸餾策略的對比實驗。

  不同網絡以及蒸餾策略之間的結果差異,值越小差異越小。

  蒸餾策略與訓練週期的關係。

  整體性能的對比。

  ImageNet上預訓練模型的在其它訓練集上的遷移效果。

  不同優化器、數據增強、正則化的對比,尋找最佳的訓練策略和配置。

  224分辨率預訓練的DeiT在不同數據集上用不同分辨率fine-tune的效果。

Conclusion


  論文基於改進訓練配置以及一種新穎的蒸餾方式,提出了僅用ImageNet就能訓練出來的Transformer網絡DeiT。在蒸餾學習時,DeiT以卷積網絡作爲teacher,能夠結合當前主流的數據增強和訓練策略來進一步提高性能。從實驗結果來看,效果很不錯。



如果本文對你有幫助,麻煩點個贊或在看唄~
更多內容請關注 微信公衆號【曉飛的算法工程筆記】

work-life balance.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章