本文首發於微信公衆號:NewBeeNLP,歡迎關注獲取更多幹貨資源。
上一篇Transformers Assemble(PART III) 重點在transformer位置信息的改進,這一集挑選了幾篇都帶有「Sparse」的標籤,主要關注點在於transformer結構的複雜度問題。先來看看都有哪些:
來自OpenAI的工作,同樣關注於原始Transformer的計算複雜度問題,尤其是在面對長序列輸入的情況。爲此,論文中將full attention進行分解,通過多個sparse attention來代替,在不犧牲性能的情況下降低複雜度至O ( n n ) O(n \sqrt{n}) O ( n n ) 。
下圖是三種不同的注意力形式,其中上半部分表示一個6x6圖像的像素之間相互attend,下半部分表示對應的connectivity matrix。
(a)原始Transformer的full attention;
(b)Strided Attention :這種方式主要應用於圖像或者音頻,每一個位置通過attend其對應的行和列來獲取信息,兩個head的具體表示爲:第一個head用於attend該位置前面的l l l 個位置,第二個head用於attend間隔l l l 的位置(如果輸入是圖像l l l 爲圖像的寬,則attend對應的列):
A i ( 1 ) = { t , t + 1 , … , i } for t = max ( 0 , i − l )
A_{i}^{(1)}=\{t, t+1, \ldots, i\} \text { for } t=\max (0, i-l)
A i ( 1 ) = { t , t + 1 , … , i } for t = max ( 0 , i − l ) A i ( 2 ) = { j : ( i − j ) m o d l = 0 }
A_{i}^{(2)}=\{j:(i-j) \bmod l=0\}
A i ( 2 ) = { j : ( i − j ) m o d l = 0 } (c)Fixed Attention :這種方式主要應用於像文本
之類沒有周期性的數據,首先將文本分成固定長度的塊,然後第一個head處理該塊中該位置之前的所有元素,第二個head處理每個塊的最後一部分的固定大小的部分。
Other Tricks
上面就是Attention主要的改進,文中還涉及了一些其他的tricks。
pre-activation residual block
來自Identity mappings in deep residual networks 可以使Transformer的訓練更加容易
H 0 = embed ( X , W e ) H k = H k − 1 + resblock ( H k − 1 ) y = softmax ( norm ( H N ) W out )
\begin{aligned}
H_{0} &=\text { embed }\left(X, W_{e}\right) \\
H_{k} &=H_{k-1}+\text { resblock }\left(H_{k-1}\right) \\
y=& \operatorname{softmax}\left(\operatorname{norm}\left(H_{N}\right) W_{\text {out}}\right)
\end{aligned}
H 0 H k y = = embed ( X , W e ) = H k − 1 + resblock ( H k − 1 ) s o f t m a x ( n o r m ( H N ) W out ) 其中resblock(H)的計算如下a ( H ) = dropout ( attention ( norm ( H ) ) ) b ( H ) = dropout ( f f ( norm ( H + a ( H ) ) ) ) resblock ( H ) = a ( H ) + b ( H )
\begin{aligned}
a(H) &=\text { dropout }(\text { attention }(\operatorname{norm}(H))) \\
b(H) &=\operatorname{dropout}(\mathrm{ff}(\operatorname{norm}(H+a(H)))) \\
& \text { resblock }(H)=a(H)+b(H)
\end{aligned}
a ( H ) b ( H ) = dropout ( attention ( n o r m ( H ) ) ) = d r o p o u t ( f f ( n o r m ( H + a ( H ) ) ) ) resblock ( H ) = a ( H ) + b ( H )
Gradient check-pointing
Efficient block-sparse attention kernels
Mixed-precision training
reference
這篇論文也是對vanilla Transformer的改進,提出了Adaptively Sparse Transformers (AST),優化的兩個關鍵就在其名字中:
Sparse: 通過替換Softmax函數爲α − e n t m a x \alpha-entmax α − e n t m a x 達到稀疏注意力;
Adaptively: 每個attention head都是模型可自動學習的;
作者指出與先前的sparse transformer
(就是上面兩個~)研究不同的是,他們的這一方法可以關注在不連續的輸入集合,如下圖:
Sparse Attention
softmax函數所有結果都不爲0,並且最終所有元素之和爲1,這樣的特性決定了相對重要的部分的權值會“縮水”。這一方向的研究很多,作者選用了最近提出的alpha-entmax :
α -entmax ( z ) : = argmax p ∈ Δ d p ⊤ z + H α T ( p ) \alpha \text { -entmax }(\boldsymbol{z}):=\underset{\boldsymbol{p} \in \Delta^{d}}{\operatorname{argmax}} \boldsymbol{p}^{\top} \boldsymbol{z}+\mathrm{H}_{\boldsymbol{\alpha}}^{\mathrm{T}}(\boldsymbol{p})
α -entmax ( z ) : = p ∈ Δ d a r g m a x p ⊤ z + H α T ( p )
H α T ( p ) : = { 1 α ( α − 1 ) ∑ j ( p j − p j α ) , α ≠ 1 − ∑ j p j log p j , α = 1
\mathrm{H}_{\alpha}^{\mathrm{T}}(\boldsymbol{p}):=\left\{\begin{array}{ll}
{\frac{1}{\alpha(\alpha-1)} \sum_{j}\left(p_{j}-p_{j}^{\alpha}\right),} & {\alpha \neq 1} \\
{-\sum_{j} p_{j} \log p_{j},} & {\alpha=1}
\end{array}\right.
H α T ( p ) : = { α ( α − 1 ) 1 ∑ j ( p j − p j α ) , − ∑ j p j log p j , α = 1 α = 1
AST
對於Transformer類模型的功能至關重要的是,不同的head會捕獲不同的語言現象,這讓我們想到對於不同的head,使用不同的α \alpha α 值,使其自適應地讓一些head稀疏化,一些head更接近softmax。利用上面的α − e n t m a x \alpha-entmax α − e n t m a x 替換原始的softmax函數後,將α \alpha α 看成是與其他網絡參數意義的可學習參數,通過隨機梯度進行優化。但是通過梯度方法對其自動優化並不容易,然後作者在下面就開始一系列數學推導。。。
TODO
我實在是看不動了。。。
後面的實驗和分析也非常有意思的。。。
大家記得看,我先溜了。。。
PS. 在油管上發現了作者的分享視頻,放在reference裏
Reference
Motivation和上一篇論文一樣,如下圖,對於文本I thanked him with all my heart, and I asked him, 'why are you helping me?'
,vanilla Transformer(藍色標記)會對所有元素都有注意,而噪音的注意力會對效果產生影響;新提出的顯式稀疏注意力機制(橙色標記)只會關注文本的t o p k topk t o p k 個attention score最大的元素,從而移除無關信息。
具體實現也非常簡單易於實現,且不會增加額外的內存和計算開銷。
沿用vanilla transformer的attention計算公式得到attention score,
P = Q K T d
P=\frac{Q K^{\mathrm{T}}}{\sqrt{d}}
P = d Q K T
假定分值越大的元素其相關性越大,計算Masking矩陣。找出P P P 中每行的k k k 個最大元素,記錄其位置,並得到一個threshold vector,t = [ t 1 , t 2 , ⋯ , t l Q ] t=\left[t_{1}, t_{2}, \cdots, t_{l_{Q}}\right] t = [ t 1 , t 2 , ⋯ , t l Q ]
將Making矩陣應用到原始P P P 矩陣上,
M ( P , k ) i j = { P i j if P i j ≥ t i ( k -th largest value of row i ) − ∞ if P i j < t i ( k -th largest value of row i )
\mathcal{M}(P, k)_{i j}=\left\{\begin{array}{ll}
{P_{i j}} & {\text { if } P_{i j} \geq t_{i}(k \text { -th largest value of row } i)} \\
{-\infty} & {\text { if } P_{i j}<t_{i}(k \text { -th largest value of row } i)}
\end{array}\right.
M ( P , k ) i j = { P i j − ∞ if P i j ≥ t i ( k -th largest value of row i ) if P i j < t i ( k -th largest value of row i ) 反向傳播時,
∂ M i j ∂ P k l = 0 ( i ≠ k or j ≠ l ) ∂ M i j ∂ P i j = { 1 if P i j ≥ t i ( k − th largest value of row i ) 0 if P i j < t i ( k − th largest value of row i )
\begin{aligned}
&\frac{\partial M_{i j}}{\partial P_{k l}}=0 \quad(i \neq k \text { or } j \neq l)\\
&\frac{\partial M_{i j}}{\partial P_{i j}}=\left\{\begin{array}{ll}
{1} & {\text { if } P_{i j} \geq t_{i}(k-\text { th largest value of row } i)} \\
{0} & {\text { if } P_{i j}<t_{i}(k-\text { th largest value of row } i)}
\end{array}\right.
\end{aligned}
∂ P k l ∂ M i j = 0 ( i = k or j = l ) ∂ P i j ∂ M i j = { 1 0 if P i j ≥ t i ( k − th largest value of row i ) if P i j < t i ( k − th largest value of row i )
歸一化,
A = softmax ( M ( P , k ) )
A=\operatorname{softmax}(\mathcal{M}(P, k))
A = s o f t m a x ( M ( P , k ) ) 反向傳播時,
∂ A i j ∂ P k l = ∑ m = 1 l Q ∑ n = 1 l K ∂ A i j ∂ M m n ∂ M m n ∂ P k l = ∂ A i j ∂ M k l ∂ M k l ∂ P k l = { ∂ A i j ∂ M k l if P i j ≥ t i ( k -th largest value of row i ) 0 if P i j < t i ( k -th largest value of row i )
\begin{aligned}
\frac{\partial A_{i j}}{\partial P_{k l}} &=\sum_{m=1}^{l_{Q}} \sum_{n=1}^{l_{K}} \frac{\partial A_{i j}}{\partial M_{m n}} \frac{\partial M_{m n}}{\partial P_{k l}} \\
&=\frac{\partial A_{i j}}{\partial M_{k l}} \frac{\partial M_{k l}}{\partial P_{k l}} \\
&=\left\{\begin{array}{cl}
{\frac{\partial A_{i j}}{\partial M_{k l}}} & {\text { if } P_{i j} \geq t_{i}(k \text { -th largest value of row } i)} \\
{0} & {\text { if } P_{i j}<t_{i}(k \text { -th largest value of row } i)}
\end{array}\right.
\end{aligned}
∂ P k l ∂ A i j = m = 1 ∑ l Q n = 1 ∑ l K ∂ M m n ∂ A i j ∂ P k l ∂ M m n = ∂ M k l ∂ A i j ∂ P k l ∂ M k l = { ∂ M k l ∂ A i j 0 if P i j ≥ t i ( k -th largest value of row i ) if P i j < t i ( k -th largest value of row i )
輸出表示,
C = A V
C=A V
C = A V
整體流程如下最右圖所示,
另外,參數k k k 的選擇至關重要,當k k k 取與序列長度一致時即爲vanilla transformer。作者在NMT實驗中發現當k = 8 k=8 k = 8 時效果最好。
reference