Fast BERT論文解讀


轉載來源:https://zhuanlan.zhihu.com/p/143027221
alt

自從 BERT 出現後,似乎 NLP 就走上了大力出奇跡的道路。模型越來越大參數越來越多,這直接導致我們需要的資源和時間也越來越多。發文章搞科研似乎沒有什麼,但是這些大模型很難在實際工業場景落地,不只是因爲成本過高,也因爲推理速度不支持線上實際情況。最近好多文章都開始針對 BERT 進行瘦身,不管是蒸餾還是減少層數還是參數共享等,都是爲了 BERT 系列模型能夠更小更快的同時不丟失精度。
論文地址:https://arxiv.org/pdf/2004.02178.pdf​arxiv.org

一、概述

FastBERT 是 ACL2020 新鮮出爐的一篇關於提高 BERT 推理速度的文章,這篇文章的思想還是很巧妙的,作者發現 12 層的 BERT 用來對一些簡單問題做分類有點大材小用了。那麼我們可以不可對於簡單問題只用較少層數來解答,對於複雜問題採用全部層數的 BERT 呢?作者在這裏對 BERT 的每層輸出後面都接了分類器,如果在淺層模型就有很高的置信度對樣本進行分類,那麼久不再走後面的層了,如果置信度不高,那就繼續走後面,這樣就極大地縮短了推理的時間。這裏每層的分類器並不是單獨訓練的,而是使用一個總的 teacher classifier 蒸餾出來的,從實驗結果可以看出採用自蒸餾要比直接訓練效果好。

二、模型詳解

BackBone

整個模型的骨架就是採用 BERT 來實現,所有 BERT 系列的模型都可以套用在這裏。BERT 在這裏的作用還是一個強大的特徵提取器的作用,在多層 transformer 堆疊的後面緊跟着一個 teacher classifier,這個在 fine-tune 階段會進行訓練,後面用來蒸餾每一層的 student classifier。整體的模型結構見下圖:


圖裏面的 Branch 就是每個 student classifier,它們具有和 teacher classifier 一樣的結構。在實際推理的時候,從底層開始往上,如果有很高的置信度就 early output,不再僅需往後面走了。

Model Training

模型的訓練一共包括三個部分,一個是主要骨架模型的預訓練,這裏和傳統的 BERT 模型是一樣的;然後就是整個骨架模型的 fine-tuning,這裏會訓練 teacher classifier;最後是對 teacher classifier 進行蒸餾的到 student classifier。預訓練和 fine-tuning 沒啥好說的,和 BERT 是一摸一樣的,這裏主要介紹一下 self-distillation。自蒸餾和傳統的蒸餾方式最大的不同就是 teacher 模型和 student 模型是一樣的在一個模型裏面,傳統的方式往往需要單獨設計 student 模型。見下圖:

但是自蒸餾的話就不存在這個問題了,自蒸餾使用 teacher classifier 的輸出 Pt 以及 student 的輸出 Ps,然後計算他們的 KL 散度,通過優化所有 student KL 散度 loss 的合來確保 student 和 teacher 的分佈越來越相似。具體公式如下:
DKL(ps,pt)=i=1Nps(i)logps(i)pt(j)D_{K L}\left(p_{s}, p_{t}\right)=\sum_{i=1}^{N} p_{s}(i) \cdot \log \frac{p_{s}(i)}{p_{t}(j)}
Loss(ps0,,psL2,pt)=i=0L2DKL(psi,pt)\operatorname{Loss}\left(p_{s_{0}}, \ldots, p_{s_{L-2}}, p_{t}\right)=\sum_{i=0}^{L-2} D_{K L}\left(p_{s_{i}}, p_{t}\right)

Adaptive Inference

在推理階段作者使用了自適應的推理,簡單來說就是每層的 BERT 都會有一個結果,作者定義了一個輸出結果不確定性的度量,用來衡量每層輸出的結果是否可信,公式如下:
Uncertainty=i=1Nps(i)logps(i)log1N\text {Uncertainty}=\frac{\sum_{i=1}^{N} p_{s}(i) \log p_{s}(i)}{\log \frac{1}{N}}
根據公式我們可以發現,這個不確定性就是用熵來衡量的。熵越大代表結果越不可信,如果某一層的不確定性小於一個閾值,那麼我們就對這層的結果進行輸出,從而提高了推理速度。我們可能會發現一個問題,如果在淺層的準確率很低,後面也就沒辦法了。所以個人感覺這個模型效果好的原因應該是基於一個假設:“數據集中大部分樣本都是簡單樣本”。這裏我有點想知道,如果只用 teacher 來整理一個最底層的 student,不知道效果會咋樣。

三、實驗

作者在 12 個數據集上對模型效果進行了對比,除了傳統的 BERT,也對比了 DistilBERT。結果如下:


從實驗結果中可以看出來,FastBERT 在提升速度的同時對於精度的減少還是比較小的。但是這個實驗美中不足的是在一些數據集上 FASTBERT 和 DistilBERT 並沒有壓縮到同等的數量級,這在一定程度上沒法真實的比較兩個模型的效果。

作者通過實驗證明了自蒸餾的效果,結果如下:

通過上圖可以發現,加入自蒸餾後,計算複雜度下降了非常多,但是 Acc 幾乎沒有下降。同時文中提到的假設:不確定性低,準確率高。作者也通過實驗證明了這一假設:


同時作者也給出了每層不確定性分佈的實驗,結果如下:


通過這個實驗結果結合不確定性低準確率高的假設,我們就可以發現,其實很多樣本在淺層就有很低的不確定性,所以完全沒有必要繼續走到後面的層去進行分類。

三、結論

作者通過充分利用 BERT 每層的輸出,發現很多簡單的樣本其實在淺層就已經可以很好的分類從而不需要走到最後,極大地提高了 BERT 的推理速度。文章的思路還是非常的新奇的,不是傳統的減少模型參數量來提高速度的方法,而是將訓練中的 early stop 思想用到了推理上。同時自蒸餾的想法也很有啓發性,避免了傳統蒸餾方式需要自己設計 student model 的情況。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章