參考資料
本文是下列資料的總結:
[1] 李宏毅視頻 59:36 開始
[2] Seq2Seq中Exposure Bias現象的淺析與對策
[3] Bridging the Gap between Training and Inferencefor Neural Machine Translation(2019ACL最佳長論文)
[4] Self-critical Sequence Training
原因
Seq2Seq模型會遇到常說的Exposure Bias現象,原因是在訓練階段和預測階段會遇到mismatch。訓練階段使用的是Teacher Forcing,也就是decoder在某時刻的輸入是上一時刻的ground truth(真實標籤)。然而在預測階段只能使用上一時刻decoder的輸出來作爲這一時刻的輸入,從而導致mismatch。
下圖來自資料 [1] 李宏毅老師的視頻:
在decoder的第二個時間步輸入了模型在第一個時間步的輸出B(預測錯誤),而不是在訓練階段能夠拿到的真實標籤A,下一個時間步就會到達沒有經過充分訓練或甚至完全沒有探索過的結點,也就是decoder在訓練階段在第二個時間步只擬合了條件分佈 而沒有擬合好條件分佈 。具體的例子可以見參考資料[2] 的“簡單例子”一節。
解決辦法
- Scheduled Sampling.
- 通過強化學習中的policy gradient直接優化BLEU
Scheduled Sampling
下圖來自資料[1],73分鐘左右。
也就是在第二個時間步開始,輸入有一定的概率 使用的是真實標籤reference,有 的概率使用的是模型在上一個時間步的輸出。而概率 隨着訓練的進行應該逐漸衰減至0,最後就是完全使用模型的輸出作爲輸入,這樣就與預測階段匹配了。
Sentence Level Oracle Word + Gumbel Noise
在參考資料 [3] 中,將採樣自模型輸出而作爲下一個時間步的輸入的詞稱爲 oracle word.
他們提出的方法與 Scheduled Sampling 的整體思路一致,只不過對於 oracle word 的選擇多了一些設計。
The oracle word should be a word similar to the ground truth or a synonym. Using different strategies will pro-duce a different oracle word.
論文作者認爲,oracle word應該與ground truth詞是同義詞或近義詞,然後論文給出了兩種得到 oracle word 的方案:
1、Word-Level Oracle Word,這個就是 Scheduled Sampling 使用的方案
2、Sentence-Level Oracle Word,使用 Beam-Search 先得到候選 decoder 輸出,然後根據所關注的指標(例如BLEU分數)來選出分數最高的輸出句子,將這個句子的單詞作爲每一步的 oracle word.
注意:使用 Beam-Search 得到的 decoder 輸出不一定和 Ground Truth 句子 等長,所以需要對 Beam-Search 過程做一些修正:如果某一步的最高概率的詞是結束符然而此時長度還不夠 ,就選概率第二高的詞;如果某一步產生完字符後長度就到達 了,然而這一步的概率最高詞不是結束符,就強制選擇結束符
除了提出新的 oracle word 的選擇方案,作者還對每一步採樣 oracle word 的過程使用了 Gumbel-Max 技巧,從而引入了 Gumbel Noise,增強魯棒性。
注:我本人對論文使用 Gumbel Max 的作用不確定是不是這樣理解,如果有更好的理解歡迎提出。我對於 Gumbel Max 的理解來自於 這篇博客,我自己也寫過一個Demo來演示Gumbel Max的作用,地址: 戳這裏
對抗訓練
參考資料 [2] 作者認爲,其實前面所述方案的原理在於給訓練階段引入了擾動,讓模型在有擾動的情況下依然可以預測正確。所以作者提出了兩種帶來擾動的方案:
1、啓發式的隨機替換。50%的概率不做改變;50%的概率把輸入序列中30%的詞替換掉,替換對象爲原目標序列的任意一個詞。
2、梯度懲罰
注:至於爲什麼梯度懲罰等價於對抗訓練,可以參考該作者的另一篇博客:對抗訓練淺談:意義、方法和思考(附Keras實現)
作者通過實驗說明了這兩種方法都有一定效果。
基於強化學習直接優化BLEU
見參考資料 [4],主要工作有MIXER 及其改進。本人對這部分沒有深入瞭解,以後有需要再進一步學習。