BERT 可解釋性-從“頭”說起


轉載來源:https://zhuanlan.zhihu.com/p/148729018
alt

一、背景介紹

搜索場景下用戶搜索的 query 和召回文章標題 (title) 的相關性對提升用戶的搜索體驗有很大幫助。query-title 分檔任務要求針對 query 和 title 按文本相關性進行 5 個檔位的分類(1~5 檔),各檔位從需求滿足及語義匹配這兩方面對 query-doc 的相關度進行衡量,檔位越大表示相關性越高,如 1 檔表示文本和語義完全不相關,而 5 檔表示文本和語義高度相關,完全符合 query 的需求。
alt

我們嘗試將 Bert 模型應用在 query-title 分檔任務上,將 query 和 title 作爲句對輸入到 bert 中,取最後一層 cls 向量用做 5 分類 (如上圖),最後得到的結果比 LSTM-Attention 交互式匹配模型要好。雖然知道了 bert 解決這個問題,我們更好奇的是 " 爲什麼 ":爲什麼 bert 的表現能這麼好?這裏面有沒有可解釋的部分呢?

因爲 Multi-head-attention 是 bert 的主要組成部分,所以我們從 “頭” 入手,希望弄清楚各個 head 對 bert 模型有什麼作用。爲了研究某個 head 對模型的影響,我們需要比較有這個 head 和沒有這個 head 模型的前後表現。這裏定義一下 HEAD-MASK 操作,其實就是針對某個 head,直接將這個 head 的 attention 值置成 0,這樣對於任何輸入這個 head 都只能輸出 0 向量。

通過 HEAD-MASK 操作對各個 head 進行對比實驗,發現了下面幾個有趣的點

  • attention-head 很冗餘 / 魯棒,去掉 20% 的 head 模型不受影響
  • 各層 transformer 之間不是串行關係,去掉一整層 attention-head 對下層影響不大
  • 各個 head 有固定的功能
    • 某些 head 負責分詞
    • 某些 head 提取語序關係
    • 某些 head 負責提取 query-title 之間 term 匹配關係

下面我們開始實驗正文,看看這些結論是怎麼得到的

二、Bert 模型 Attention-Head 實驗

attention-head 是 bert 的基本組成模塊,本次實驗想要研究各個 head 都對模型作出了什麼貢獻。通過 Mask 掉某個 head,對比模型前後表現的差異來研究這個 head 對模型有什麼樣的作用 (對訓練好的 bert 做 head-mask,不重新訓練,對比測試集的表現)。

bert-base 模型共 12 層每層有 12 個 head,下面實驗各個 head 提取的特徵是否有明顯的模式 (Bert 模型爲在 query-title 數據上 finetune 好的中文模型)

2.1 Attention-Head 比較冗餘

標準大小的 bert 一共有 12*12 共 144 個 head. 我們嘗試對訓練好的 bert 模型,隨機 mask 掉一定比例的 head, 再在測試數據集上測試分檔的準確率 (五分類)。

下圖的柱狀圖的數值表示相比於 bseline(也就是不做任何 head-mask) 模型 acc 的相對提升, 如 + 1% 表示比 baseline 模型的 acc 相對提高了 1%,從下面的圖可以看到,隨機 mask 掉低於 20% 的 head,在測試數據集上模型的 acc 不會降低,甚至當 mask 掉 10% 的 head 的時候模型表現比不做 head mask 的時候還提升了 1%。當 mask 掉超過一定數量的 head 後,模型表現持續下降,mask 掉越多表現越差。
alt

同時爲了弄清楚底層和高層的 transformer 哪個對於 query-title 分類更加的重要,分別對底層 (layer0 ~ layer5) 和高層 (layer6~layer11) 的 head 做 mask, 去掉的 head 比例控制在 0~50%(佔總 head 數量)之間,50% 表示去掉了底層或者是高層 100% 的 head 下面的圖很清晰的說明了底層和高層的 attention-head 關係,橙色部分表示只 mask 掉高層 (6 - 11 層) 的 head, 藍色部分表示只 mask 掉底層 (0 - 5 層) 的 head。

顯然高層的 attention-head 非常的依賴底層的 head,底層的 attention-head 負責提取輸入文本的各種特徵,而高層的 attention 負責將這些特徵結合起來。具體表現在當 mask 掉底層 (0~5 層) 的 80% 的 head(圖中橫座標爲 40%)和 mask 掉底層的 100% 的 head(圖中橫座標爲 50%)時,模型在測試數據集上表現下降劇烈(圖中藍色部分),說明了去掉大部分的底層 head 後只依賴高層的 head 是不行的,高層的 head 並沒有提取輸入的特徵。相反去掉大部分高層的 head 後模型下降的並沒有那麼劇烈(圖中橙色部分),說明了底層的 head 提取到了很多對於本任務有用的輸入特徵,這部分特徵通過殘差連接可以直接傳導到最後一層用做分類。
alt

這個結論後面也可以用於指導模型蒸餾,實驗結果表明底層的 transformer 比高層的 transformer 更加的重要,顯然我們在蒸餾模型時需要保留更多的底層的 head

那麼對於模型來說是否有某些層的 head 特別能影響 query-title 分類呢?假設將 bert 中所有的 attention-head 看做一個 12*12 的方陣,下面是按行 mask 掉一整行 head 後模型在測試數據上的表現,柱狀圖上的數值表示相比 baseline 模型的相對提升。
alt

可以看到 mask 掉第 5 層~第 9 層的 head 都模型都有比較大的正面提升,特別是當去掉整個第 8 層的 attention-head 的時候測試數據準確率相對提升了 2.3%,從上圖可以得到兩個結論:

  • Bert 模型非常的健壯或者是冗餘度很高
  • Bert 模型各層之間不是串行依賴的關係,信息並不是通過一層一層 transformer 層來傳遞的

bert 模型非常的健壯或者是冗餘度很高,直接去掉一整層的 attention-head 並不會對模型的最終表現有太大的影響。 直接去掉整層的 attention-head 模型表現並沒有大幅度的下降,說明各層提取的特徵信息並不是一層一層的串行傳遞到分類器的,而是通過殘差連接直接傳導到對應的層。

2.2 某些 head 負責判斷詞的邊界 (使得字模型帶有分詞信息)

在我們的 query-title 分檔場景中,發現詞粒度的 bert 和字粒度的 bert 最終的表現是差不多的,而對於 rnn 模型來說字粒度的 rnn 很難達到詞粒度 rnn 的效果,我們希望研究一下爲什麼詞粒度和字粒度的 bert 表現差不多。

使用的 bert 可視化工具 bert_viz [2][2]觀察各層 attention-head 的 attention 權重分佈,可以發現某些 head 帶有很明顯的分詞信息。推測這部分 attention-head 是專門用於提取分詞信息的 head。噹噹前的字可能是詞的結尾時,att 權重會偏向 sep, 當這個字爲詞的結尾可能性越大 (常見的詞結尾),sep 的權重會越高。噹噹前字不是詞結尾時,att 會指向下一個字。這種模式非常明顯,直接拿這個 attention-head 的結果用於分詞準確率爲 70%。

下面 gif 爲我們模型中第 1 層第 3 個 head 的 attention 分佈權重圖,可以發現 attention 權重很明顯帶有詞的邊界信息,噹噹前的字是結尾時 attention 權重最大的 token 爲 “SEP”,若當前字不是結尾時 attention 權重最大的爲下一個字。
alt

這種用於提取分詞信息的 head 有很多,且不同的 head 有不同的分詞粒度,如果將多個粒度的分詞綜合考慮 (有一個 head 分詞正確就行),則直接用 attention-head 切詞的準確率在 96%,這也是爲什麼詞粒度 bert 和字粒度 bert 表現差不多的原因。
alt

猜測字粒度 bert 帶詞邊界信息是通過 bert 的預訓練任務 MLM 帶來的,語言模型的訓練使得 bert 對各個字之間的組合非常的敏感,從而能夠區分詞的邊界信息。

2.3 某些 head 負責編碼輸入的順序

我們知道 bert 的輸入爲 token_emb+pos_emb+seg_type_emb 這三個部分相加而成,而文本輸入的順序完全是用 pos_emb 來隱式的表達。bert 中某些 head 實際上負責提取輸入中的位置信息。這種 attention-head 有明顯的上下對齊的模式,如下圖:
alt

原輸入: query=“京東小哥”, title=“京東小哥最近在幹嘛”,bert 模型判定爲 4 檔

將 title 順序打亂: query=“京東小哥”, title=“近東嘛最都在乾哥小京”,bert 模型判定爲 2 檔 將 title 順序打亂: query=“京東小哥”, title=“近東嘛最都在乾哥小京”,mask 掉 7 個懷疑用於提取語序的 head,bert 模型判定爲 3 檔

下面的圖分別對比了不做 mask,隨機 mask 掉 7 個 head(重複 100 次取平均值),mask 掉 7 個特定的 head(懷疑帶有語序信息的 head) 從下面的圖看到,mask 掉 7 個特定的 head 後整體分檔提升爲 3 檔,而隨機 mask 掉 7 個 head 結果仍然爲 2 檔,且檔位概率分佈和不 mask 的情況差別不大。

這個 case 說明了我們 mask 掉的 7 個特定的 head 應該是負責提取輸入的順序信息,也就是語序信息。將這部分 head mask 掉後,bert 表現比較難察覺到 title 中的亂序,從而提升了分檔。
alt

2.4 某些 head 負責 query 和 title 中相同部分的 term 匹配

query 和 title 中是否有相同的 term 是我們的分類任務中非常關鍵的特徵,假如 query 中大部分 term 都能在 title 中找到,則 query 和 title 相關性一般比較高。如 query=“京東小哥” 就能完全在 title=“京東小哥最近在幹嘛” 中找到,兩者的文本相關性也很高。我們發現部分 attention-head 負責提取這種 term 匹配特徵,這種 head 的 attention 權重分佈一般如下圖,可以看到上句和下句中相同 term 的權重很高 (顏色越深表示權重越大)。
alt

其中在第 2~ 第 4 層有 5 個 head 匹配的模式特別明顯。我們發現雖然 bert 模型中 attention-head 很冗餘,去掉一些 head 對模型不會有太大的影響,但是有少部分 head 對模型非常重要,下面展示這 5 個 head 對模型的影響,表格中的數值表示與 baseline 模型的 acc 相對提升值。
alt

利用測試數據作爲標準,分別測試隨機 mask 掉 5 個 head 和 mask 掉 5 個指定的 head(這些 head 在 attention 可視化上都有明顯的 query-title 匹配的模式)。從結果可以看到去掉這些負責 query-title 匹配的 head 後模型表現劇烈下降,只去掉這 5 個 head 就能讓模型表現下降 50%。甚至 mask 掉 0~5 層其他 head,只保留這 5 個 head 時模型仍維持 baseline 模型 82% 的表現,說明了 query-title 的 term 匹配在我們的任務中是非常重要的。

這也許是爲什麼雙塔 bert 在我們的場景下表現會那麼差的原因 (Bert+LSTM 實驗中兩個模型結合最後的表現差於只使用 Bert, Bert 的輸入爲雙塔輸入),因爲 query 和 title 分別輸入,使得這些 head 沒有辦法提取 term 的匹配特徵 (相當於 mask 掉了這些 head),而這些匹配特徵對於我們的分類任務是至關重要的

2.4.1 finetune 對於負責 term 匹配 attention-head 的影響

在 query-title 分檔任務中 query 和 title 中是否有相同的 term 是很重要的特徵,那麼在 finetune 過程中負責 query-title 中相同 term 匹配的 head 是否有比較明顯的增強呢?

下面以 case 爲例說明: query=“我在伊朗長大” title=“假期電影《我在伊朗長大》”

下圖展示了 query-title 數據 ***finetune 前 ***** 某個 ** 負責 term 匹配的 head 的 attention 分配圖
alt

沒有 finetune 前,可以看到某些 head 也會對上下句中重複的 term 分配比較大的 attention 值,這個特質可能是來自預訓練任務 NSP(上下句預測)。因爲假如上句和下句有出現相同的 term,則它們是上下句的概率比較大,所以 bert 有一些 head 專門負責提取這種匹配的信息。

除了上下句相同的 term 有比較大的注意力,每個 term 對自身也有比較大的注意力權重(體現在圖中對角線上的值都比較大) 爲了更直觀的看訓練前後哪部分的 attention 值有比較大的改變,分別展示訓練後 attention 增強 (微調前 - 微調後> 0) 和訓練後 attention 減弱 (微調前 - 微調後 < 0) 的 attention 分配圖。可以觀察到比較明顯的幾個點:

  • query 和 title 中 term 匹配的 attention 值變大了 從下圖可以看到, query 和 title 中具有相同 term 時 attention 相比於訓練前是有比較大的增強。說明在下游任務 (query-title 分檔) 訓練中增強了這個 head 的相同 term 匹配信息的抽取能力。
    alt

  • term 和自身的 attention 變小了 模型將重點放在找 query 和 title 中是否有相同的 term,弱化了 term 對自身的注意力權重
    alt

  • 分隔符 sep 的 attention 值變小了。論文指出當某個 token 的 attention 指向 sep 時表示一種不分配的狀態 (即此時沒有找到合適的 attention 分配方式),在經過 finetune 後 term 指向 sep 的權重變小了,表示經過 query-title 數據訓練後這個 head 的 attention 分配更加的明確了。

2.4.2 是否有某個 head 特別能影響模型

從上面的實驗可以看到,bert 模型有比較多冗餘的 head。去掉一部分這些 head 並不太影響模型,但是有少部分 head 特別能影響模型如上面提到的負責提取上下句中 term 匹配信息的 head,只去掉 5 個這種 head 就能讓模型的表現下降 50%。那麼是否有某個 head 特別能影響結果呢?

下面實驗每次只 mask 掉一個 head,看模型在測試數據中表現是否上升 / 下降。下圖中將 bert 的 144 個 head 看作 12X12 的矩陣,矩陣內每個元素表示去掉這個 head 後模型在測試數據上的表現。其中 0 表示去掉後對模型的影響不太大。元素內的值表示相對於 baseline 的表現提升,如 + 1% 表示相比 baseline 的 acc 提高了 1%。
alt

可以看到對於 bert 的大部分 head,單獨去掉這個 head 對模型並不會造成太大的影響,而有少部分 head 確實特別能影響模型,比如負責上下句 (query-title) 中相同 term 匹配的 head。即使去掉一個這種 head 也會使得模型的表現下降。同時注意到高層 (第 10 層) 有一個 head 去掉後模型表現變化也很大,實驗發現這個 head 功能是負責抽取底層 head 輸出的特徵,也就是 3-4 層中 head 抽取到輸入的 query-title 有哪些相同 term 特徵後,這部分信息會傳遞到第 10 層進一步進行提取,最後影響分類。

2.4.3 高層 head 是如何提取底層 head 特徵 - 一個典型 case

上圖中,在第 10 層有一個 head 去掉後特別能影響模型,觀察其 attention 的分佈,cls 的 attention 都集中在 query 和 title 中相同的 term 上,似乎是在對底層 term 匹配 head 抽取到的特徵進一步的提取,將這種匹配特徵保存到 cls 中 (cls 最後一層會用於分類)。
alt

在沒有做任何 head-mask 時, 可以看到 cls 的 attention 主要分配給和 query 和 title 中的共同 term “紫熨斗”,而 mask 掉 5 個 2~4 層的 head(具有 term 匹配功能) 時, 第 10 層的 cls 注意力分配明顯被改變,分散到更多的 term 中。
alt

這個 case 展示了高層 attention-head 是如何依賴底層的 head 的特徵,進一步提取底層的特徵並最後作爲重要特徵用於 query-title 分類。

結語

本文主要探討了在 query-title 分類場景下, bert 模型的可解釋性。主要從 attention-head 角度入手,發現 attention 一方面非常的冗餘,去掉一部分 head 其實不會對模型造成多大的影響。另外一方面有一些 head 卻非常的能影響模型,即使去掉一個都能讓模型表現變差不少。同時發現不同的 head 實際上有特定的功能,比如底層的 head 負責對輸入進行特徵提取,如分詞、提取輸入的語序關係、提取 query 和 title(也就是上下句) 中相同的 term 信息等。這部分底層的 head 提取到的特徵會通過殘差連接送到高層的 head 中,高層 head 會對這部分特徵信息進行進一步融合,最終作爲分類特徵輸入到分類器中。

本文重點討論了哪些 head 是對模型有正面作用,也就是去掉這些 head 後模型表現變差了。但是如果知道了哪些 head 爲什麼對模型有負面作用,也就是爲什麼去掉某些 head 模型效果會更好,實際上對於我們有更多的指導作用。這部分信息能夠幫助我們在模型加速,提升模型表現上少走彎路。

參考文獻

[1] Clark K, Khandelwal U, Levy O, et al. What Does BERT Look At? An Analysis of BERT’s Attention[J]. arXiv preprint arXiv:1906.04341, 2019.

[2] Vig J. A multiscale visualization of attention in the transformer model[J]. arXiv preprint arXiv:1906.05714, 2019.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章