本文首发于微信公众号:NewBeeNLP,欢迎关注获取更多干货资源。
上一篇Transformers Assemble(PART III) 重点在transformer位置信息的改进,这一集挑选了几篇都带有「Sparse」的标签,主要关注点在于transformer结构的复杂度问题。先来看看都有哪些:
来自OpenAI的工作,同样关注于原始Transformer的计算复杂度问题,尤其是在面对长序列输入的情况。为此,论文中将full attention进行分解,通过多个sparse attention来代替,在不牺牲性能的情况下降低复杂度至O ( n n ) O(n \sqrt{n}) O ( n n ) 。
下图是三种不同的注意力形式,其中上半部分表示一个6x6图像的像素之间相互attend,下半部分表示对应的connectivity matrix。
(a)原始Transformer的full attention;
(b)Strided Attention :这种方式主要应用于图像或者音频,每一个位置通过attend其对应的行和列来获取信息,两个head的具体表示为:第一个head用于attend该位置前面的l l l 个位置,第二个head用于attend间隔l l l 的位置(如果输入是图像l l l 为图像的宽,则attend对应的列):
A i ( 1 ) = { t , t + 1 , … , i } for t = max ( 0 , i − l )
A_{i}^{(1)}=\{t, t+1, \ldots, i\} \text { for } t=\max (0, i-l)
A i ( 1 ) = { t , t + 1 , … , i } for t = max ( 0 , i − l ) A i ( 2 ) = { j : ( i − j ) m o d l = 0 }
A_{i}^{(2)}=\{j:(i-j) \bmod l=0\}
A i ( 2 ) = { j : ( i − j ) m o d l = 0 } (c)Fixed Attention :这种方式主要应用于像文本
之类没有周期性的数据,首先将文本分成固定长度的块,然后第一个head处理该块中该位置之前的所有元素,第二个head处理每个块的最后一部分的固定大小的部分。
Other Tricks
上面就是Attention主要的改进,文中还涉及了一些其他的tricks。
pre-activation residual block
来自Identity mappings in deep residual networks 可以使Transformer的训练更加容易
H 0 = embed ( X , W e ) H k = H k − 1 + resblock ( H k − 1 ) y = softmax ( norm ( H N ) W out )
\begin{aligned}
H_{0} &=\text { embed }\left(X, W_{e}\right) \\
H_{k} &=H_{k-1}+\text { resblock }\left(H_{k-1}\right) \\
y=& \operatorname{softmax}\left(\operatorname{norm}\left(H_{N}\right) W_{\text {out}}\right)
\end{aligned}
H 0 H k y = = embed ( X , W e ) = H k − 1 + resblock ( H k − 1 ) s o f t m a x ( n o r m ( H N ) W out ) 其中resblock(H)的计算如下a ( H ) = dropout ( attention ( norm ( H ) ) ) b ( H ) = dropout ( f f ( norm ( H + a ( H ) ) ) ) resblock ( H ) = a ( H ) + b ( H )
\begin{aligned}
a(H) &=\text { dropout }(\text { attention }(\operatorname{norm}(H))) \\
b(H) &=\operatorname{dropout}(\mathrm{ff}(\operatorname{norm}(H+a(H)))) \\
& \text { resblock }(H)=a(H)+b(H)
\end{aligned}
a ( H ) b ( H ) = dropout ( attention ( n o r m ( H ) ) ) = d r o p o u t ( f f ( n o r m ( H + a ( H ) ) ) ) resblock ( H ) = a ( H ) + b ( H )
Gradient check-pointing
Efficient block-sparse attention kernels
Mixed-precision training
reference
这篇论文也是对vanilla Transformer的改进,提出了Adaptively Sparse Transformers (AST),优化的两个关键就在其名字中:
Sparse: 通过替换Softmax函数为α − e n t m a x \alpha-entmax α − e n t m a x 达到稀疏注意力;
Adaptively: 每个attention head都是模型可自动学习的;
作者指出与先前的sparse transformer
(就是上面两个~)研究不同的是,他们的这一方法可以关注在不连续的输入集合,如下图:
Sparse Attention
softmax函数所有结果都不为0,并且最终所有元素之和为1,这样的特性决定了相对重要的部分的权值会“缩水”。这一方向的研究很多,作者选用了最近提出的alpha-entmax :
α -entmax ( z ) : = argmax p ∈ Δ d p ⊤ z + H α T ( p ) \alpha \text { -entmax }(\boldsymbol{z}):=\underset{\boldsymbol{p} \in \Delta^{d}}{\operatorname{argmax}} \boldsymbol{p}^{\top} \boldsymbol{z}+\mathrm{H}_{\boldsymbol{\alpha}}^{\mathrm{T}}(\boldsymbol{p})
α -entmax ( z ) : = p ∈ Δ d a r g m a x p ⊤ z + H α T ( p )
H α T ( p ) : = { 1 α ( α − 1 ) ∑ j ( p j − p j α ) , α ≠ 1 − ∑ j p j log p j , α = 1
\mathrm{H}_{\alpha}^{\mathrm{T}}(\boldsymbol{p}):=\left\{\begin{array}{ll}
{\frac{1}{\alpha(\alpha-1)} \sum_{j}\left(p_{j}-p_{j}^{\alpha}\right),} & {\alpha \neq 1} \\
{-\sum_{j} p_{j} \log p_{j},} & {\alpha=1}
\end{array}\right.
H α T ( p ) : = { α ( α − 1 ) 1 ∑ j ( p j − p j α ) , − ∑ j p j log p j , α = 1 α = 1
AST
对于Transformer类模型的功能至关重要的是,不同的head会捕获不同的语言现象,这让我们想到对于不同的head,使用不同的α \alpha α 值,使其自适应地让一些head稀疏化,一些head更接近softmax。利用上面的α − e n t m a x \alpha-entmax α − e n t m a x 替换原始的softmax函数后,将α \alpha α 看成是与其他网络参数意义的可学习参数,通过随机梯度进行优化。但是通过梯度方法对其自动优化并不容易,然后作者在下面就开始一系列数学推导。。。
TODO
我实在是看不动了。。。
后面的实验和分析也非常有意思的。。。
大家记得看,我先溜了。。。
PS. 在油管上发现了作者的分享视频,放在reference里
Reference
Motivation和上一篇论文一样,如下图,对于文本I thanked him with all my heart, and I asked him, 'why are you helping me?'
,vanilla Transformer(蓝色标记)会对所有元素都有注意,而噪音的注意力会对效果产生影响;新提出的显式稀疏注意力机制(橙色标记)只会关注文本的t o p k topk t o p k 个attention score最大的元素,从而移除无关信息。
具体实现也非常简单易于实现,且不会增加额外的内存和计算开销。
沿用vanilla transformer的attention计算公式得到attention score,
P = Q K T d
P=\frac{Q K^{\mathrm{T}}}{\sqrt{d}}
P = d Q K T
假定分值越大的元素其相关性越大,计算Masking矩阵。找出P P P 中每行的k k k 个最大元素,记录其位置,并得到一个threshold vector,t = [ t 1 , t 2 , ⋯ , t l Q ] t=\left[t_{1}, t_{2}, \cdots, t_{l_{Q}}\right] t = [ t 1 , t 2 , ⋯ , t l Q ]
将Making矩阵应用到原始P P P 矩阵上,
M ( P , k ) i j = { P i j if P i j ≥ t i ( k -th largest value of row i ) − ∞ if P i j < t i ( k -th largest value of row i )
\mathcal{M}(P, k)_{i j}=\left\{\begin{array}{ll}
{P_{i j}} & {\text { if } P_{i j} \geq t_{i}(k \text { -th largest value of row } i)} \\
{-\infty} & {\text { if } P_{i j}<t_{i}(k \text { -th largest value of row } i)}
\end{array}\right.
M ( P , k ) i j = { P i j − ∞ if P i j ≥ t i ( k -th largest value of row i ) if P i j < t i ( k -th largest value of row i ) 反向传播时,
∂ M i j ∂ P k l = 0 ( i ≠ k or j ≠ l ) ∂ M i j ∂ P i j = { 1 if P i j ≥ t i ( k − th largest value of row i ) 0 if P i j < t i ( k − th largest value of row i )
\begin{aligned}
&\frac{\partial M_{i j}}{\partial P_{k l}}=0 \quad(i \neq k \text { or } j \neq l)\\
&\frac{\partial M_{i j}}{\partial P_{i j}}=\left\{\begin{array}{ll}
{1} & {\text { if } P_{i j} \geq t_{i}(k-\text { th largest value of row } i)} \\
{0} & {\text { if } P_{i j}<t_{i}(k-\text { th largest value of row } i)}
\end{array}\right.
\end{aligned}
∂ P k l ∂ M i j = 0 ( i = k or j = l ) ∂ P i j ∂ M i j = { 1 0 if P i j ≥ t i ( k − th largest value of row i ) if P i j < t i ( k − th largest value of row i )
归一化,
A = softmax ( M ( P , k ) )
A=\operatorname{softmax}(\mathcal{M}(P, k))
A = s o f t m a x ( M ( P , k ) ) 反向传播时,
∂ A i j ∂ P k l = ∑ m = 1 l Q ∑ n = 1 l K ∂ A i j ∂ M m n ∂ M m n ∂ P k l = ∂ A i j ∂ M k l ∂ M k l ∂ P k l = { ∂ A i j ∂ M k l if P i j ≥ t i ( k -th largest value of row i ) 0 if P i j < t i ( k -th largest value of row i )
\begin{aligned}
\frac{\partial A_{i j}}{\partial P_{k l}} &=\sum_{m=1}^{l_{Q}} \sum_{n=1}^{l_{K}} \frac{\partial A_{i j}}{\partial M_{m n}} \frac{\partial M_{m n}}{\partial P_{k l}} \\
&=\frac{\partial A_{i j}}{\partial M_{k l}} \frac{\partial M_{k l}}{\partial P_{k l}} \\
&=\left\{\begin{array}{cl}
{\frac{\partial A_{i j}}{\partial M_{k l}}} & {\text { if } P_{i j} \geq t_{i}(k \text { -th largest value of row } i)} \\
{0} & {\text { if } P_{i j}<t_{i}(k \text { -th largest value of row } i)}
\end{array}\right.
\end{aligned}
∂ P k l ∂ A i j = m = 1 ∑ l Q n = 1 ∑ l K ∂ M m n ∂ A i j ∂ P k l ∂ M m n = ∂ M k l ∂ A i j ∂ P k l ∂ M k l = { ∂ M k l ∂ A i j 0 if P i j ≥ t i ( k -th largest value of row i ) if P i j < t i ( k -th largest value of row i )
输出表示,
C = A V
C=A V
C = A V
整体流程如下最右图所示,
另外,参数k k k 的选择至关重要,当k k k 取与序列长度一致时即为vanilla transformer。作者在NMT实验中发现当k = 8 k=8 k = 8 时效果最好。
reference