TransformerXL解讀

背景

對語言模型建模,RNN和Transformer都是能提取長距離的依賴關係的特徵提取器。RNN方面,由於本身的recurrent機制,可以接受任意長度的序列作爲輸入,但是由於梯度消失和爆炸(gradient vanishing and explosion)和無法並行計算等問題,實際效果不佳;Transformer作爲新貴,雖然不存在上述問題,但是由於實際不可能輸入任意長度的詞encoding到fixed length,只能先按某個固定最大長度分chunks再對每個chunks計算,這就帶來了兩個問題,即模型無法建立chunks之間的依賴關係(對長文本處理不好)和因爲邊界問題對開頭的幾個單詞預測的不好(context fragmentation).

對此TransformerXL解決了Vanilla Transformer存在的這些問題(XL的意思的extra long,針對超長文本)。論文參考《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》

Transformer回顧

在這裏插入圖片描述
構建語言模型通常是用t時刻之前的序列預測t時刻詞的概率,通過模型將t時刻之前的輸入encode到固定長度的representation, 然後經過線性變換後的softmax得到t時刻詞的概率。
而基於傳統的transformer即vinilla transformer存在的問題主要在於如何將任意長度的sequence編碼成固定長度的representation。實際由於算力問題,有一種解決辦法是將長序列的文本劃分爲多個chunks, 每個chunk單獨訓練模型忽略不同chunk之間的依賴關係。論文見《Character-Level Language Modeling with Deeper Self-Attention》
預測的時候,每次和訓練的時候chunk一樣輸入固定長度的序列,但是隻預測下一個token, 之後shift右移一位當作新的chunk,重新計算下一個token,這種方式較爲費力。

TransformerXL模型架構

Segment-Level Recurrence with State Reuse

爲了解決vinilla transformer的問題,這裏提出了循環機制。

During training, the hidden state sequence computed for the previous segment is fixed and cached to be reused as an extended context when the model processes the next new segment。

訓練的時候TransformerXL對每個segment的hidden state保留cache作爲下一個segment的輸入,這樣就把不同segment的長距離依賴關係進行捕捉。
在這裏插入圖片描述
其中h表示hiddent state,n表示第n層transformer,t表示第t個segment, SG表示stop gradient,記不算上一個segment的梯度;計算公式可以看出,和vinilla transformerr相比,區別在於計算k和v的時候,是利用上一個segment的hidden state和當前segment的hidden state進行concat之後的結果,這樣就能捕捉更長的依賴關係了。由於當前層的hidden state是由下一層的包含當前時刻和前L-1個state計算出來的,依次類推,最長依賴關係正比於O(N × L),N爲segment的總個數, L爲每個segment的固定長度,通常L>N 。
在這裏插入圖片描述
另外出了上述的好處外,在預測的時候,由於緩存了之前的hidden state,再計算預測之後的token的時候不需要重新計算,比vinilla transformer快上千倍。

Relative Positional Encodings

上述循環機制還有個問題沒有解決,就是transformer的position encoding只在segment中由絕對位置編碼,卻沒有跨越segment的相對位置編碼,這樣模型無法區分不同segment的相同位置的區別。
傳統的transformer的attention計算公式爲(Exi+Ui)WqWk(Exj+Uj)(E_{x_i}+U_i)W_qW_k(E_{x_j}+U_j), 其中WqW_qWkW_k爲key和value對應的的矩陣,E爲token的embedding,U爲position encoding,展開如下
在這裏插入圖片描述
TransformerXL對此做了如下改變在這裏插入圖片描述

  1. 將計算key向量的絕對位置編碼UjU_j換成了RijR_{i-j}, 這是一個sinusoid encoding matrix,不是學習得到的
  2. 將query向量的絕對位置向量替換成了可訓練的向量u和v,這是因爲這裏採用相對位置編碼,i位置的絕對編碼沒有意義
  3. WkW_k替換成了分別基於位置(location-based)和內容(content-based)的矩陣,計算得到不同的key向量

總的來說,Relative Positional Encodings就是在計算注意力分數時,用相對位置RijR_{i-j}和學習了的相對位置vvuu向量來代替絕對位置編碼UiU_iUjU_j

改造後的TransformerXL公式爲:
在這裏插入圖片描述

總結

TransformerXL通過循環機制利用上個segment的信息,並且將絕對位置編碼改成相對位置編碼,解決了普通Transformer無法建立超過固定長度文本的長依賴問題和context fragmentation,在預測的效率也大幅提升。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章