Distilling transformers into simple neural networks with unlabeled transfer data論文解讀

Distilling transformers into simple neural networks with unlabeled transfer data

論文地址:https://arxiv.org/pdf/1910.01769.pdf

motivation

一般來說,蒸餾得到的student模型與teacher模型的準確率還存在差距。文章利用大量in-domain unlabeled transfer set以及有限數量的標記訓練實例來彌補這一差距。

運用在文本分類任務上。

文章提出的兩個蒸餾方法:

  • hard distillation:用fine-tuned teacher對大量的無標籤數據進行標註,標註硬標籤。然後用這些augmented data來對student進行監督學習。loss函數是交叉墒。
  • soft distillation:用教師模型在unlabeled data上生成的logits和內部表示,來對student進行不同蒸餾方式(不同loss函數)的訓練。

模型的輸入用Wordpiece tokenization

student模型

embedding層+BiLSTM層+最大池化層(因爲如果只用最後一個hidden state,對於長語句來說信息量不夠)

損失函數爲交叉墒函數

teacher模型

用標註數據來fine tune預訓練模型,用的是最後一層的 [CLS]向量,loss是交叉墒。

選什麼特徵來蒸餾

teacher logits

對於unlabeled數據,教師產生的logits和學生生成的分類score之間的loss,均方誤差。

hidden teacher representations

用教師學到的中間表示來指導學生模仿自己。文中用的是教師模型最後一層。

因爲兩個模型結構不同,最後一層的維度也會不同。用Gelu激活函數進行轉換到相同維度。
然後依然是兩個模型最後一層表示之間的均方誤差作爲loss。

(文中提到,他們發現均方誤差得到的結果比KL散度更好些)

整體框架

損失函數:

三個loss函數組合起來,不同的權重。

loss函數

較高的α值使學生模型更多地關注容易實現的目標。而較高的γ可使學生專注於困難的目標,並使模型適應嘈雜的地面真相標籤。後者不是這項工作的重點,因此在進一步分析中將其省略。

訓練方式:

  1. 聯合訓練,三個loss函數並在一起,
  2. 逐漸解凍的分層訓練。
  • 第一步,先訓練LRL ,學習參數,模仿teacher最後一層的表示。
  • 第二步,以LCE和LLL作爲loss,但是不能一下子優化所有參數,會造成災難性遺忘。因此,將每層的參數frozen,然後從最後一層一層的解凍。直到收斂。
  1. 先蒸餾,再finetune,與方式2相似,
  • 第一步,先以LRL 和LLL作爲loss,不需要labeled數據。
  • 第二步,用labeled數據進行fine tune,loss爲LCE,微調的時候和方式2一樣,逐層解凍。
    這樣相當於第一步得到了一個蒸餾後的student,然後之後就可以根據不同的任務數據來fine tune它。

實驗

四個數據集:

  • IMDB:電影評論情感分類
  • Elec:亞馬遜電子產品的情感分類
  • DbPedia:Wikipedia的主題分類
  • Ag News:新聞文章的主題分類。

一些參數

數據比例 train:validation=9:1

Tensorflow

4 Tesla V100 gpus

Adadelta優化器 + early stopping (也用了Adam,Adam收斂更快,但是最終結果沒有Adadelta好)

所有層dropout=0.4,Bi-LSTM層dropout=0.2

300d的glove預訓練詞向量。

LSTM隱層維度=600,batch size=64,

loss函數中的 α = β = 10, γ = 1

教師模型

選的是BERT-base和BERT-large

學生模型:

  1. Bi-LSTM encoder,最後一個隱向量+soft Max作分類,用基礎的空格分詞法。只用labeled data,交叉墒爲損失函數來訓練。
  2. 不蒸餾的Bi-LSTM encoder+Max pooling,用wordpiece tokenization,訓練loss和1相同。
  3. 蒸餾的student。和上文中提到的一樣,三個loss,3種不同的訓練方式。

數據處理部分:

有的數據集沒有unlabeled data,所以就把數據集分爲兩部分,一部分有標籤,另一部分去掉標籤作爲unlabeled data。

實驗結果:

用BERT-base做老師

在這裏插入圖片描述
可以看出用了wordpiece和加了Max pooling層的模型比普通的RNN模型效果好。然後通過蒸餾可以明顯的提升學生模型的準確率,甚至高於teacher。

用更大的BERT-large做老師:

在這裏插入圖片描述
可以看出,teacher性能越好,蒸餾得到的student也越好。(符合思維,好老師教出好學生)

參數量比較

Distilled Student BERT Base BERT Large
13M 110M 340M

Distilling Hard Targets vs. Soft Logits

hard distillation是指,finetune之後的teacher對unlabeled data進行預測標註,然後用原本的標註數據和teacher標註的數據一起,對student進行蒸餾。不涉及到logits和最後層表示。
在這裏插入圖片描述

Distillation with Less Training Labels

  • 每類留500個標註數據的蒸餾結果:
    在這裏插入圖片描述
  • 每類留100個標註數據:

每類100個,對BERT large進行fine tune時,DbPedia和Ag News兩個數據集的微調結果和之前500個的時候差不多,但是IMDB和Elec數據集只有50%的準確率,幾乎是隨機了。

於是文章又做了一個實驗,用這兩個數據集對pretrained BERT接着進行預訓練,然後用每類100個進行微調。最後再用每類100個進行蒸餾,得到的學生模型甚至超過了BERT large。

在這裏插入圖片描述

不同訓練方式的對比:

可以看出,方式2的結果最好,先學習最後一層的表示,再學習logits和CE。

在這裏插入圖片描述

總結

這篇論文的創新性不高,思路其實和Distilling Task-Specific Knowledge from BERT into Simple Neural Networks這篇論文差不多。

但是文章裏的實驗做的很詳細,該比較的點都比較了。論文也寫的很清晰易懂。

可以看出蒸餾到小模型確實是Bert這些大模型很好的應用方向。自己也做了相關的實驗,確實有效。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章