Transformer與Transformer-XL

回顧Transformer

在NLP領域,對語言建模最常用的模型就是RNNS(包括LSTM),但是RNNS存在一些問題,比如學習長期依賴的能力很弱(LSTM語言模型平均只能建模200個上下文詞語),而且學習速度也很慢。

在2017年,谷歌的一位學者提出了Transformer架構,其示意圖如下圖所示:Transformer不懂的可以看博客圖解Transformer
在這裏插入圖片描述
雖然Transformer相比LSTM可以建模更長的序列,但是也需要對輸入序列設置一個固定的長度(比如BERT中默認長度是512)。如果輸入序列長度小於固定長度可以通過填充的方式來解決,如果序列長度大於固定長度,常用的做法是將序列切割成多個segments,切割的時候並沒有考慮句子的自然邊界,而是根據固定長度來劃分序列,在訓練的時候每個segment單獨訓練,並沒有考慮相鄰的segment之間的上下文信息,所以segment之間的語義是不完整的,如下圖(a)所示:
在這裏插入圖片描述
在圖(a)中,輸入序列是[x1,,,x8][x_1,,,x_8],固定長度是4,將輸入序列分成2個segment進行訓練,而且每個segment都要從頭開始訓練,segment2並沒有利用到segment1的上下文信息,這種現象稱爲上下文碎片(context fragmentation)。

在預測的時候,會對固定長度的segment做計算,一般取最後一個位置的隱向量作爲輸出。爲了緩解context fragmentation,在每做完一次預測之後,就對整個序列向右移動一個位置,再做一次計算,如上圖(b)所示,這導致計算效率非常低。

Transformer-XL

爲了解決上述的問題,在Transformer的基礎上,提出了Transformer-XL,有兩點創新:Segment-Level Recurrence和Relative Position Encodings

segment level recurrence

在對當前segment進行處理的時候,緩存並利用上一個segment中所有layer的隱向量序列,而且上一個segment的所有隱向量序列只參與前向計算,不再進行反向傳播,這就是所謂的segment-level Recurrence。
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
訓練和預測過程如上圖所示。
在圖(a)中,與Transformer不同的是綠色線的部分,綠色線表示了上一個segment傳遞給當前segment的上下文信息,當前segment的第n層的每個隱向量的計算,都依賴於第n-1層的當前位置的隱向量和前L-1個隱向量。
在圖(b)中,當前segment(除了第一個segment)中的Transnformer-XL的第nn層的每個節點都依賴前面(n1)(L1)(n-1)(L-1)個token,所以最後一層的節點依賴的token最多。n通常要比L小很多,比如在BERT中,N=12或者24,L=512,依賴關係長度可以近似爲 O(NL)O(N*L) 。在對長文本進行計算的時候,可以緩存上一個segment的隱向量的結果,不必重複計算,大幅提高計算效率。

Relative Position Encodings

在vanilla Trm中,爲了表示序列中token的順序關係,在模型的輸入端,對每個token的輸入embedding,加一個位置embedding。位置編碼embedding或者採用正弦\餘弦函數來生成,或者通過學習得到。在Trm-XL中,這種方法行不通,每個segment都添加相同的位置編碼,多個segments之間無法區分位置關係。Trm-XL放棄使用絕對位置編碼,而是採用相對位置編碼,在計算當前位置隱向量的時候,考慮與之依賴token的相對位置關係。具體操作是,在算attention score的時候,只考慮query向量與key向量的相對位置關係,並且將這種相對位置關係,加入到每一層Trm的attention的計算中。
在這裏插入圖片描述

整體計算公式

在這裏插入圖片描述
總結,Trm-XL爲了解決長序列的問題,對上一個segment做了緩存,可供當前segment使用,但是也帶來了位置關係問題,爲了解決位置問題,又打了個補丁,引入了相對位置編碼。
參考:
Transformer-XL介紹
Transformer-XL解讀(論文 + PyTorch源碼)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章