解決的問題
Transformer的自注意力機制可以讓長距離的單詞直接聯繫,可以很容易地學習到句子之間的長距離依賴。但是在將Transformer應用在語言模型時,核心的問題在於如何將任意長度的context編碼成固定長度的上下文變量。
普遍的做法是將整個語料庫劃分成較短的片段,在每個片段上訓練模型。但是這麼做很有幾個問題:
- 最大可能依賴長度不會超過片段的長度
- 語料庫按照固定長度而不是按照語義或者句子的分界劃分進行分片,導致分片之間無法共享信息。這樣情況下訓練的語言模型會丟失一部分上下文信息,導致收斂速度和性能不符合預期。論文把這個問題成爲上下文碎片問題。
貢獻
- 在Transformer的基礎上提出segment-level recurrence mechanism、新的位置編碼模式兩種方法使得Transformer-XL能夠在不破壞時間一致性的情況下,學習固定長度之外依賴關係並且解決上下文碎片問題
結果
- 學習依賴的範圍比RNN(LSTM)高出80%,比Transformer高出450%,
- evaluation過程速度比Transformer提高1800倍。
- 在5個語言模型數據集上取得SOAT成績。
Recurrent mechanism機制
前一個分片的隱藏狀態被保存下來作爲計算下一個分片隱藏狀態的擴展輸入。這種方法能夠在計算當前分片隱藏狀態的時候使用到上文的信息,進而可以有效的避免上下文碎片並且增加依賴的長度。具體來說,前一個分片第n-1層的隱藏狀態被保存下來,與當前分片的第n-1層的隱藏狀態一起作用生成當前分片第n層的隱藏狀態。這種“循環”利用上個分片的隱藏狀態的方法能夠有效加長期依賴的範圍。
其中,下標分別表示分片的編號,上標表示當前計算的隱藏狀態在模型中的層數,分別表示計算注意力需要的查詢(query)、鍵(key)和值(value)向量,表示模型的參數矩陣。首先第個分片第n-1層的隱藏狀態被保留下來,做停止更新梯度的操作之後和第個分片第n-1層的隱藏狀態做拼接得到融合了上文信息的新狀態,新狀態被用來計算key和value向量,query向量則使用原始的隱藏狀態計算。最後三個向量作爲Transformer的輸入計算第分片n層的隱藏狀態。
相對位置編碼(Relative Position Encoding)
從上面介紹的Recurrent機制中,我們知道當前分片的隱藏狀態融合了前一個分片的信息,問題在於不同的分片的隱藏狀態計算過程中使用到的初始輸入的位置編碼是相同的。這樣模型無法根據位置編碼區別不同分片中相同位置的輸入單詞。
解決的辦法是在每一層的隱藏狀態計算時都加入相對位置編碼。實際上,位置編碼的作用是在計算注意力的時候提供時序的線索以組合分片中不同單詞的信息,在相對位置編碼的設定下,正弦編碼矩陣的每一行表示相對位置差值爲時的相對位置編碼。在注意力計算過程中使用相對位置編碼,使得當前分片的查詢向量可以卻分不同分片的相同位置的輸入和。
上圖是transformer的一個注意力頭使用絕對位置編碼的注意力得分你的公式,可以拆分爲四個部分。使用相對位置編碼替換絕對位置編碼分爲三步,第一步將(b)、(d)兩項末尾的絕對位置編碼替換成相對位置編碼<u,矩陣R是一個正弦編碼矩陣,沒有科學系的參數;第二步引入可訓練的參數代替©項的,對於任何的query位置,query向量都是一樣的,也就是對不同單詞的注意力傾向是相同的,同理(d)像用另一個可學習參數代替;第三步,分離兩個權重矩陣和,分別產生基於內容的key向量和基於位置的key向量。
Attention head的計算過程
整個計算過程其實就是將上面兩個部分連在一起,首先時候分段重用機制融合之前分段信息計算得到查詢、鍵、值向量,然後使用相對位置編碼替換絕對位置編碼計算注意力得分,最後經過前向網絡得到當前分段的新的隱藏狀態,輸出給下一層的Transformer模塊。
Evaluation速度巨大提高
評估的時候,Transformer每次預測一個新詞都需要將其前L個詞作爲一個分片,以保證每次預測都能使用到最長的上文的信息。但是這種預測方式需要很大的計算力。Recurrent機制能夠極大程度的加速了evaluation的過程,因爲保留了前一個分片的隱藏狀態,而不是每次預測新詞都從頭開始計算,從而達到了加速的效果。
實驗結果
Transformer-XL在五個數據集上語言模型的數據集上取得了SOAT成績,包括character-level的WikiText-103、enwiki8、text8和One Billion Word四個數據集和word-level的Penn Treebank數據集。
WikiText-103數據集包含28K篇文章的103M個字符,平均每篇文章包含3.6K個字符。這個數據集可以有效的評估模型處理長期依賴的能力。
enwik8包含了100MB未處理的Wikipedia的文本。與enwiki8相似,text8同樣包含了100MB的Wikipedia文本,區別在於移除了26個字母和空格以外的其他字符。
One Billion Word數據集,顧名思義包含有10億個單詞,但是該數據集將所有的句子的順序進行打亂,因而沒有保留句子與句子之間的長期依賴。但是transformer-XL依然在這個數據集上超過傳統Transormer,說明transformer-XL在處理短句子時的泛化能力。
Penn Treebank只包含1M個字符,作者認爲在該數據集上取得SOAT結果說明transformer-XL對於小數據集也具有良好的泛化能力。