0. 背景
機構:Google Research 、U.C. Berkeley
作者:Nikita Kitaev、Łukasz Kaiser、Anselm Levskaya
論文地址:https://arxiv.org/abs/2001.04451
收錄會議:ICLR2020
論文代碼:https://github.com/google/trax/tree/master/trax/models/reformer
0.1 摘要
基於Transformer的各種巨型模型在各種自然語言處理任務中常常能夠取得最優結果,但這些模型的訓練成本往往過高,在針對長序列文本上尤甚。爲此,本文提出兩種技術以改善基於Transformer的這類模型,名爲Reformer。第一,使用局部敏感hash,替換原始的點乘方式的attention,從而將其空間複雜度從降低到,其中表示文本序列的長度。第二,使用逆殘差層代替標準的殘差,這使得訓練過程中只需存儲一次激活值,而無需次,其中表示網絡層數。最終的結果表明Reformer性能與Transformer相當,同時在長序列上具有更高的內存效率和更快的速度。
1. 介紹
先看看Transformer模型是否真的那麼佔用資源或者說低效。以現有的最大Transformer層爲例,該Transformer層中參數量是0.5B,這需要2GB的內存。(1M=1024KB,1KB=1024Byte。所以1GB=1024M=1024x1024KB=1024x1024x1024Byte=1073741824Byte。float佔用4個Byte。0.5B即5億參數,需要的內存量爲5億*4字節=20億字節。這差不多是1.86GB即約爲2GB)對於由64Ktokens組成的序列,如果嵌入層的尺寸是1024,batch size是8,那麼激活值需要個浮點數來存儲,這又需要2GB的內存。如果每層的內存佔用只有上述提到的這些的話,那麼在單加速器上使用Transformer處理64K長度的序列也是輕而易舉。此外,如此前提下訓練BERT的整個語料庫也只需17GB的內存。然而,現實並非如此,真實環境下爲何甚至不能在單臺機器上對這些模型進行微調呢?
這是因爲上述僅僅考慮單層參數的內存佔用和輸入激活值的內存消耗,而忽略了 Transformer 在內存佔用上的主要問題:
-
長度爲的序列的 attention 的時間和空間複雜度是 ,所以對於 64K tokens的序列就會耗盡內存。
-
需要存儲激活值用於反向傳播,那麼N層模型內存佔用是單層的N倍;
-
由於中間全連接層的深度通常遠大於注意力激活層的深度,而這需要佔用很大的內存;
爲此,本文提出Reformer模型以解決上述問題,具體採用如下方案:
-
採用基於局部敏感哈希(locality-sensitive hashing,LSH)的近似注意力計算,讓注意力層的 因子變爲 ,這使得在長序列上的處理成爲可能。
-
可逆層(Reversible layer),在整個模型中只使用單個副本,所以可以消除層數因子;
-
在前饋層(feed-forward layer)分開激活和分塊處理,從而消除因子的影響,降低前饋層的內存佔用;
Reformer模型在以下3個任務上進行實驗:合成任務、文本任務(enwik8,序列長度爲64K)和圖像生成任務(imagenet-64,序列長度爲12K)。實驗結果表明Reformer結果與Transformer相當,但是更快、內存也更高效。
2. 局部敏感哈希Attention
點乘attention:
標準的Transformer使用點乘的attention,queries和keys的維度都是,values的維度是。query先與key做點乘,再除以,再輸入到softmax中得到value的權重,最後權重再與value相乘,得到最終的結果。在實際操作過程中是以矩陣方式進行批量操作,queries組成矩陣,keys組成矩陣,values組成矩陣,上述流程概況如下:
多頭attention:
上述的attention操作並行地進行h次,再輸出維度爲的輸出結果。再將這些結果拼接,再做一次投射操作得到最終的結果。即所謂的多頭attention。
高效內存attention:
先來算下上述attention機制消耗的內存。假設的尺寸爲[batch_size,length,d_model]。的尺寸爲[batch_size,length,length]。實驗中發現,當length=64k,即使batch_size=1,那麼64k*64k大小的矩陣,如果用32位浮點數來存儲的話,需要16GB內存。鑑於此,在長序列上使用Transformer顯得不切實際。但是需要注意的是,矩陣可以不必全部放在內存中,可以對每個query分別計算attention,那麼只需要在內存計算。反向傳播計算梯度時再重新計算一次。這種方式計算attention雖然低效,但是所佔用的內存與length成正比。這種方法在本文這裏作爲一種全attention的baseline。
Q,K,V從何處來?
上述討論了Q、K、V,但是一般我們只會得到大小爲[batch_size,length,d_model]的激活值,這些值是token的嵌入所組成的句向量。那麼爲了從A中得到Q、K、V,Transformer使用了3個不同的線性層(參數不同)將A投射爲Q、K、V。對於使用局部敏感哈希attention的模型,我們希望queries和keys(即Q和K)相同。只需要A投射到Q和A投射到K時採用相同線性變換參數即可,而A投射到V時採用不同參數。這種方式成爲共享QK-Transformer。實驗表明共享QK並不會影響Transformer的性能,即使添加一項的歸一化項。
Hashing attention:
在LSH attention中,假設Q、K、V的尺寸爲[batch_size,length,d_model],同時仍然使用此前介紹的多頭attention機制。那麼的尺寸爲[batch_size,length,length]。由於的計算結果主要取決於值最大的部分,對於每個query只需關注中與query最接近的點。當的長度是64k,那麼對個每個query,本文僅僅考慮其最近的的32或64個keys。如此會更加高效,那麼如何找尋最近的那些keys呢?
局部敏感哈希(LSH):
在高緯空間中找尋最近鄰可以使用局部敏感哈希(LSH)。將每個向量x通過hash函數h(x)進行映射,如果近處的向量獲得相同的hash,且具有高概率,而遠處的向量沒有,那麼這樣的hash稱爲位置敏感型hash。在此處例子中,我們實際上只要求近鄰的向量以高概率具有相同的hash值,並且hash桶也以高概率具有相同的大小。
具體是使用如Figure 1所示的隨機投射方法:
上圖的angular LSH是常用LSH算法的一個變體,它將點投射到一個單位球上,這個單位球被劃分爲預定義的區域,每個區域都有一個特定的代碼。然後一系列隨機旋轉的點定義了這些點所歸屬的桶。讓我們通過一個簡單的2D例子來說明這一點,
angular LSH的動圖說明,圖片來源
這裏有兩個點,它們投影到一個單位圓上,並以不同的角度隨機旋轉3次。可以觀察到,它們不太可能共享同一個hash桶。在後續例子中,可以看到兩個非常接近的點在3次隨機旋轉後會位於相同的hash桶:
Angular LSH最近鄰搜索的的一個簡化動畫:兩個點很接近的情況。圖片來源
如果想要得到b個hash,那麼先固定一個隨機矩陣R的大小爲。再定義,其中表示兩個向量之間的拼接。
LSH attention:
綜合考慮上述的LSH策略和hashing attention,先重寫單個query在位置i的常規attention:
其中表示query在位置i所需要attend的集合,表示配分函數(partition function)比如softmax中的歸一化項。爲了書寫清楚,這裏省略了縮放項。
對於批量操作,定義批量操作集合,當遮蔽的元素不在中,此時常規attention定義如下:
即對於不能attend到的位置,爲正無窮,那麼減去正無窮再去exp操作,其結果爲0。這樣就不需要對於每個位置i都有單獨的
在LSH attention中,query中位置i所能夠attend的限制集合被限制到一個hash桶中:
Figure 2(a-b)展示的是全attention和hash attention的對比。
圖a:常規的attention機制中,黑點代表的是softmax中佔主導的位置。注意這邊的attention使用的是encoder的attention, 否則 無法attend to 。另外,這種全attention(即encoder中的attention)的attention矩陣一般是稀疏的,但計算中並沒有利用這種稀疏性,所以可以利用這個降低時間空間複雜度。
圖b:計算query和key所歸屬的hash桶。再按照桶進行排序,同一個桶又按照原本的位置進行排序得到圖b。可以看到,同一個桶,可以出現多個query但keys很少的情況,例如圖中藍色的桶query有3個,都attend到同一個key中。由於相似的item很有可能落在同一個桶裏,所以只在每個桶內部進行attention就可以近似全attention。
圖c:爲了減緩桶中q和k不均衡問題,本文通過令使得,即使用了share-QK attention。然後先按照桶序號對queries排序,每個桶中,仍按照原本的position 位置大小排序。得到圖c。對比b圖和c圖可以看出,縱軸的k已經變成了q。時候就能保證對角線都是attend 到的而且q和k在桶中的個數一樣(因爲Q=K)。排序後的attention矩陣,相同桶的值會在對角線附近聚集。注意到圖中對角線的點爲空心,這是因爲雖然在正常情況下,q會attend to本身位置的value,但是在share-QK的實現下,如果attend to本身,會導致其值特別大,其他的值特別小,經過softmax之後,其他都是0,就自己本身是1。所以爲了避免這種情況,q不會去attend 自身位置的值,除非只有自己本身可以attend。
圖d:即使Q=K,還是會出現一個問題:有的桶中個數多,有的桶中個數少。比如一個極端情況,2個桶,其中一個桶佔據了所有的keys,另一個桶爲空,那麼LSH attention就沒有起作用。於是在圖c的基礎上,增加了chunk的操作。對輸入進行排序之後(即圖c中先桶排序,同個桶內按照token 的 position排序)得到新的序列順序,比如圖中原來的序列順序是,新的序列順序是 。每個chunk內query的上限個數爲, ( 爲輸入query的長度) ,每個桶平均大小爲,這裏假設桶中數量增加到均值兩倍的概率足夠低。對於桶中的每個query,都可以attend to自己以及前一個桶中相同hash 值的key。
小結下,LSH attention做了以下兩個事情:
第一,找到、矩陣的LSH hashes。
第二,在同一個hash桶內計算k和q向量的標準attention。
更具體來說可分爲以下5個步驟:
第一,令輸入序列queries=keys
第二,做LSH bucketing,即進行hash計算,得到每個query和key所歸屬的桶(不同顏色表示不同的桶)。
第三,根據桶編號對query進行排序,同個桶中,按照query原本的位置進行排序。
第四,對於排序後的新序列,進行 chunk 拆分
第五,對於每個query只attend自己以及自己之前的chunk,對於這些候選集中相同桶的key進行attend。
多輪LSH attention:
LSH 有近似性,即不能保證相似的輸入能在同一個桶中。爲了減輕這個問題,採用了multi-round LSH attention。即重複上述過程多次,以使類似的item以儘可能高的概率落入相同的桶中,儘量避免相似item落入不同桶。更多的細節參考附件A。
3. 可逆層
如上所述,attention的複雜度可以被減少爲與序列長度成正比,但是,參數量佔的複雜度依舊很高,如何進一步減少呢?這裏就開始嘗試解決前文介紹部分所提到的第二和第三個問題,即大量的encoder和decoder層、全連接層FFN的深度問題。
Reversible residual Network (RevNet)
RevNet的思想是每一層的activations可以根據下一層的activations推導獲得,從而不需要在內存中儲存activations。在原本的residual layer中,由公式輸出得到activations。其中F是residual 函數。在RevNet中,先將輸入分爲兩個部分和,然後通過不同residual functions: 和 得到輸出和:
再根據以下結構,從輸出獲得輸入:
Reversible Transformer
那麼如何在Transformer中引入RevNet?將attention layer和 FFN layer通過ResNet 連接,從而減少內存的消耗。具體是令F函數爲attention 層,G函數作爲FFN層。需要注意的一點是layer normalization是包含在residual blocks中的。
如此,使用可逆的Transformer在每一層中就無需存儲激活值,也就避免了這一項。可逆層代替標準的殘差層,可以在訓練過程中只存儲一次激活,而不是N次。
Chunking
上述消除了項的影響,深層的網絡仍然佔有大量內存。在FFN中中間隱藏層的緯度通常非常大,比如或者更大。由於FFN的計算與序列中的位置完全無關,因此計算可以被分割成個塊,以降低內存的使用。雖然該操作其實可並行處理,但是每次只計算一個chunk,通過時間換取內存空間。
另外,可逆操作和反向傳播操作也分塊處理。除FFN之外,對於詞彙量大的模型(單詞類型),還對輸出處的log- probability分塊,並一次計算序列各部分的損失。
4. 實驗結果
對圖像生成任務imagenet64(序列長度爲12K)和文本任務enwik8-64K(即序列長度爲64K)進行了實驗,評價了可逆層、共享query-key、LSH attention對內存、精度和速度的影響。
可逆層和共享query-key的影響:
Figure 3中的左部分驗證共享query-key的影響。從perplexity曲線結果可以看出,共享QK attention並不會明顯遜色於常規attention。且在enwik8數據集中收斂更快。換句話說,使用共享QK attention並不會犧牲準確性。
Figure 3中的右部分驗證的是可逆層的影響。實驗中對比的可逆層和常規Transformer參數量相同,且學習曲線看起來也幾乎相同。這些結果表明,可逆Transformer在節省內存的同時並不會犧牲精度。
LSH attention的影響:
如Figure 4所示,可以看出隨着hash數的增多精度也提升了。
更大的Reformer模型:
Figure 5展示了不同層數的Reformer在envik8和imagenet64上的表現。下圖(左)是Big Reformer隨層數變化指標結果,20層依然無壓力。而下圖(右)是普通attention和LSH attention在不同序列長度的速度比較,當序列很長的時候,LSH具有顯著的優勢。
5. 總結
Reformer將Transformer的建模能力與能夠在長序列上高效執行的體系結構相結合,使其即使處理大模型時,也可以使用較小的內存。這將有助於大型、海量參數化的Transformer模型變得更廣泛可用。此外,處理長序列的能力爲Reformer在許多生成任務上的使用開闢了道路。除了生成非常長的連貫的文本外
Reformer可以把Transformer模型的能力帶到其他領域,如時間序列預測、音樂、圖像等。