ICLR2020论文阅读笔记reformer: THE EFFICIENT TRANSFORMER

0. 背景

机构:Google Research 、U.C. Berkeley
作者:Nikita Kitaev、Łukasz Kaiser、Anselm Levskaya
论文地址:https://arxiv.org/abs/2001.04451
收录会议:ICLR2020
论文代码:https://github.com/google/trax/tree/master/trax/models/reformer

0.1 摘要

基于Transformer的各种巨型模型在各种自然语言处理任务中常常能够取得最优结果,但这些模型的训练成本往往过高,在针对长序列文本上尤甚。为此,本文提出两种技术以改善基于Transformer的这类模型,名为Reformer。第一,使用局部敏感hash,替换原始的点乘方式的attention,从而将其空间复杂度从O(L2)\mathrm{O} (L^{2})降低到O(LlogL)\mathrm{O}(L \log L),其中LL表示文本序列的长度。第二,使用逆残差层代替标准的残差,这使得训练过程中只需存储一次激活值,而无需NN次,其中NN表示网络层数。最终的结果表明Reformer性能与Transformer相当,同时在长序列上具有更高的内存效率和更快的速度。

1. 介绍

先看看Transformer模型是否真的那么占用资源或者说低效。以现有的最大Transformer层为例,该Transformer层中参数量是0.5B,这需要2GB的内存。(1M=1024KB,1KB=1024Byte。所以1GB=1024M=1024x1024KB=1024x1024x1024Byte=1073741824Byte。float占用4个Byte。0.5B即5亿参数,需要的内存量为5亿*4字节=20亿字节。这差不多是1.86GB即约为2GB)对于由64Ktokens组成的序列,如果嵌入层的尺寸是1024,batch size是8,那么激活值需要64K×1K×8=0.5B64K \times 1K \times 8=0.5B个浮点数来存储,这又需要2GB的内存。如果每层的内存占用只有上述提到的这些的话,那么在单加速器上使用Transformer处理64K长度的序列也是轻而易举。此外,如此前提下训练BERT的整个语料库也只需17GB的内存。然而,现实并非如此,真实环境下为何甚至不能在单台机器上对这些模型进行微调呢?

这是因为上述仅仅考虑单层参数的内存占用和输入激活值的内存消耗,而忽略了 Transformer 在内存占用上的主要问题:

  • 长度为LL的序列的 attention 的时间和空间复杂度是 O(L2)O(L^2),所以对于 64K tokens的序列就会耗尽内存。

  • 需要存储激活值用于反向传播,那么N层模型内存占用是单层的N倍;

  • 由于中间全连接层的深度dffd_{ff}通常远大于注意力激活层的深度dmodeld_{model},而这需要占用很大的内存;

为此,本文提出Reformer模型以解决上述问题,具体采用如下方案:

  • 采用基于局部敏感哈希(locality-sensitive hashing,LSH)的近似注意力计算,让注意力层的O(L2)O(L^2) 因子变为O(LlogL)\mathrm{O}(L \log L) ,这使得在长序列上的处理成为可能。

  • 可逆层(Reversible layer),在整个模型中只使用单个副本,所以可以消除层数因子NN

  • 在前馈层(feed-forward layer)分开激活和分块处理,从而消除dffd_{ff}因子的影响,降低前馈层的内存占用;

Reformer模型在以下3个任务上进行实验:合成任务、文本任务(enwik8,序列长度为64K)和图像生成任务(imagenet-64,序列长度为12K)。实验结果表明Reformer结果与Transformer相当,但是更快、内存也更高效。

2. 局部敏感哈希Attention

点乘attention:
标准的Transformer使用点乘的attention,queries和keys的维度都是dkd_k,values的维度是dvd_v。query先与key做点乘,再除以dk\sqrt{d_k},再输入到softmax中得到value的权重,最后权重再与value相乘,得到最终的结果。在实际操作过程中是以矩阵方式进行批量操作,queries组成矩阵QQ,keys组成矩阵KK,values组成矩阵VV,上述流程概况如下:
 Attention (Q,K,V)=softmax(QKTdk)V \text { Attention }(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V

多头attention:
上述的attention操作并行地进行h次,再输出维度为dvd_v的输出结果。再将这些结果拼接,再做一次投射操作得到最终的结果。即所谓的多头attention。

高效内存attention:
先来算下上述attention机制消耗的内存。假设QKVQ,K,V的尺寸为[batch_size,length,d_model]。QKTQK^T的尺寸为[batch_size,length,length]。实验中发现,当length=64k,即使batch_size=1,那么64k*64k大小的矩阵,如果用32位浮点数来存储的话,需要16GB内存。鉴于此,在长序列上使用Transformer显得不切实际。但是需要注意的是,QKTQK^T矩阵可以不必全部放在内存中,可以对每个query分别计算attention,那么只需要在内存计算softmax(qiKTdk)V\operatorname{softmax}\left(\frac{q_i K^{T}}{\sqrt{d_{k}}}\right) V。反向传播计算梯度时再重新计算一次。这种方式计算attention虽然低效,但是所占用的内存与length成正比。这种方法在本文这里作为一种全attention的baseline。

Q,K,V从何处来?
上述讨论了Q、K、V,但是一般我们只会得到大小为[batch_size,length,d_model]的激活值AA,这些值是token的嵌入所组成的句向量。那么为了从A中得到Q、K、V,Transformer使用了3个不同的线性层(参数不同)将A投射为Q、K、V。对于使用局部敏感哈希attention的模型,我们希望queries和keys(即Q和K)相同。只需要A投射到Q和A投射到K时采用相同线性变换参数即可,而A投射到V时采用不同参数。这种方式成为共享QK-Transformer。实验表明共享QK并不会影响Transformer的性能,即使添加一项dkd_k的归一化项。

Hashing attention:
在LSH attention中,假设Q、K、V的尺寸为[batch_size,length,d_model],同时仍然使用此前介绍的多头attention机制。那么QKTQK^T的尺寸为[batch_size,length,length]。由于softmax(QKT)\operatorname{softmax}(QK^T)的计算结果主要取决于值最大的部分,对于每个query只需关注KK中与query最接近的点。当KK的长度是64k,那么对个每个query,本文仅仅考虑其最近的的32或64个keys。如此会更加高效,那么如何找寻最近的那些keys呢?

局部敏感哈希(LSH):
在高纬空间中找寻最近邻可以使用局部敏感哈希(LSH)。将每个向量x通过hash函数h(x)进行映射,如果近处的向量获得相同的hash,且具有高概率,而远处的向量没有,那么这样的hash称为位置敏感型hash。在此处例子中,我们实际上只要求近邻的向量以高概率具有相同的hash值,并且hash桶也以高概率具有相同的大小。

具体是使用如Figure 1所示的随机投射方法:
在这里插入图片描述
上图的angular LSH是常用LSH算法的一个变体,它将点投射到一个单位球上,这个单位球被划分为预定义的区域,每个区域都有一个特定的代码。然后一系列随机旋转的点定义了这些点所归属的桶。让我们通过一个简单的2D例子来说明这一点,

angular LSH的动图说明,图片来源

这里有两个点,它们投影到一个单位圆上,并以不同的角度随机旋转3次。可以观察到,它们不太可能共享同一个hash桶。在后续例子中,可以看到两个非常接近的点在3次随机旋转后会位于相同的hash桶:

Angular LSH最近邻搜索的的一个简化动画:两个点很接近的情况。图片来源

如果想要得到b个hash,那么先固定一个随机矩阵R的大小为[dkb/2][d_k,b/2]。再定义h(x)=softmax([xR;xR])h(x)=\operatorname{softmax}([xR;-xR]),其中[u;v][u;v]表示两个向量之间的拼接。

LSH attention:
综合考虑上述的LSH策略和hashing attention,先重写单个query在位置i的常规attention:
oi=jPiexp(qikjz(i,Pi))vj where Pi={j:ij}\begin{aligned} &o_{i}=\sum_{j \in \mathcal{P}_{i}} \exp \left(q_{i} \cdot k_{j}-z\left(i, \mathcal{P}_{i}\right)\right) v_{j}\\ &\text { where } \mathcal{P}_{i}=\{j: i \geq j\} \end{aligned}
其中Pi\mathcal{P}_i表示query在位置i所需要attend的集合,zz表示配分函数(partition function)比如softmax中的归一化项。为了书写清楚,这里省略了缩放项dk\sqrt{d_k}

对于批量操作,定义批量操作集合P~i={0,1,,l}Pi\tilde{\mathcal{P}}_{i}=\{0,1, \ldots, l\} \supseteq \mathcal{P}_{i},当遮蔽的元素不在Pi\mathcal{P}_{i}中,此时常规attention定义如下:
oi=jP~iexp(qikjm(j,Pi)z(i,Pi))vj where m(j,Pi)={ if jPi0 otherwise o_{i}=\sum_{j \in \widetilde{\mathcal{P}}_{i}} \exp \left(q_{i} \cdot k_{j}-m\left(j, \mathcal{P}_{i}\right)-z\left(i, \mathcal{P}_{i}\right)\right) v_{j} \quad \text { where } m\left(j, \mathcal{P}_{i}\right)=\left\{\begin{array}{ll} \infty & \text { if } j \notin \mathcal{P}_{i} \\ 0 & \text { otherwise } \end{array}\right.
即对于不能attend到的位置,m(j,Pi)m(j, \mathcal{P}_{i})为正无穷,那么qikjq_{i} \cdot k_{j}减去正无穷再去exp操作,其结果为0。这样就不需要对于每个位置i都有单独的Pi\mathcal{P}_i

在LSH attention中,query中位置i所能够attend的限制集合Pi\mathcal{P}_{i}被限制到一个hash桶中:
Pi={j:h(qi)=h(kj)}\mathcal{P}_{i}=\left\{j: h\left(q_{i}\right)=h\left(k_{j}\right)\right\}
Figure 2(a-b)展示的是全attention和hash attention的对比。
在这里插入图片描述
图a:常规的attention机制中,黑点代表的是softmax中占主导的位置。注意这边的attention使用的是encoder的attention, 否则q3q_3 无法attend to q6q_6。另外,这种全attention(即encoder中的attention)的attention矩阵一般是稀疏的,但计算中并没有利用这种稀疏性,所以可以利用这个降低时间空间复杂度。

图b:计算query和key所归属的hash桶。再按照桶进行排序,同一个桶又按照原本的位置进行排序得到图b。可以看到,同一个桶,可以出现多个query但keys很少的情况,例如图中蓝色的桶query有3个,都attend到同一个key中。由于相似的item很有可能落在同一个桶里,所以只在每个桶内部进行attention就可以近似全attention。

图c:为了减缓桶中q和k不均衡问题,本文通过令kj=qjqjk_{j}=\frac{q_{j}}{\left\|q_{j}\right\|}使得h(kj)=h(qj)h(k_{j})=h(q_{j}),即使用了share-QK attention。然后先按照桶序号对queries排序,每个桶中,仍按照原本的position 位置大小排序。得到图c。对比b图和c图可以看出,纵轴的k已经变成了q。时候就能保证对角线都是attend 到的而且q和k在桶中的个数一样(因为Q=K)。排序后的attention矩阵,相同桶的值会在对角线附近聚集。注意到图中对角线的点为空心,这是因为虽然在正常情况下,q会attend to本身位置的value,但是在share-QK的实现下,如果attend to本身,会导致其值特别大,其他的值特别小,经过softmax之后,其他都是0,就自己本身是1。所以为了避免这种情况,q不会去attend 自身位置的值,除非只有自己本身可以attend。

图d:即使Q=K,还是会出现一个问题:有的桶中个数多,有的桶中个数少。比如一个极端情况,2个桶,其中一个桶占据了所有的keys,另一个桶为空,那么LSH attention就没有起作用。于是在图c的基础上,增加了chunk的操作。对输入进行排序之后(即图c中先桶排序,同个桶内按照token 的 position排序)得到新的序列顺序sis_i,比如图中原来的序列顺序是[q1,q2,q3,q4,q5,q6][q_1,q_2,q_3,q_4,q_5,q_6],新的序列顺序是[q1,q2,q4,q3,q6,q5][q_1,q_2,q_4,q_3,q_6,q_5] 。每个chunk内query的上限个数为m=2lnbucketsm=\frac{2 l}{n_{\text {buckets}}}, (ll 为输入query的长度) ,每个桶平均大小为m=lnbucketsm=\frac{l}{n_{\text {buckets}}},这里假设桶中数量增加到均值两倍的概率足够低。对于桶中的每个query,都可以attend to自己以及前一个桶中相同hash 值的key。

小结下,LSH attention做了以下两个事情:
第一,找到QQKK矩阵的LSH hashes。
第二,在同一个hash桶内计算k和q向量的标准attention。

更具体来说可分为以下5个步骤:
第一,令输入序列queries=keys
第二,做LSH bucketing,即进行hash计算,得到每个query和key所归属的桶(不同颜色表示不同的桶)。
第三,根据桶编号对query进行排序,同个桶中,按照query原本的位置进行排序。
第四,对于排序后的新序列,进行 chunk 拆分
第五,对于每个query只attend自己以及自己之前的chunk,对于这些候选集中相同桶的key进行attend。

多轮LSH attention:
LSH 有近似性,即不能保证相似的输入能在同一个桶中。为了减轻这个问题,采用了multi-round LSH attention。即重复上述过程多次,以使类似的item以尽可能高的概率落入相同的桶中,尽量避免相似item落入不同桶。更多的细节参考附件A。

3. 可逆层

如上所述,attention的复杂度可以被减少为与序列长度成正比,但是,参数量占的复杂度依旧很高,如何进一步减少呢?这里就开始尝试解决前文介绍部分所提到的第二和第三个问题,即大量的encoder和decoder层、全连接层FFN的深度问题。

Reversible residual Network (RevNet)

RevNet的思想是每一层的activations可以根据下一层的activations推导获得,从而不需要在内存中储存activations。在原本的residual layer中,由公式y=x+F(x)y=x+F(x)输出得到activations。其中F是residual 函数。在RevNet中,先将输入xx分为两个部分x1x_1x2x_2,然后通过不同residual functions:F()F(\cdot)G()G(\cdot)得到输出y1y_1y2y_2
y1=x1+F(x2)y2=x2+G(y1)y_{1}=x_{1}+F\left(x_{2}\right) \quad y_{2}=x_{2}+G\left(y_{1}\right)
再根据以下结构,从输出获得输入:
x2=y2G(y1)x1=y1F(x2) x_{2}=y_{2}-G\left(y_{1}\right) \quad x_{1}=y_{1}-F\left(x_{2}\right)

Reversible Transformer

那么如何在Transformer中引入RevNet?将attention layer和 FFN layer通过ResNet 连接,从而减少内存的消耗。具体是令F函数为attention 层,G函数作为FFN层。需要注意的一点是layer normalization是包含在residual blocks中的。
Y1=X1+ Attention (X2)Y2=X2+ FeedForward (Y1)Y_{1}=X_{1}+\text { Attention }\left(X_{2}\right) \quad Y_{2}=X_{2}+\text { FeedForward }\left(Y_{1}\right)
如此,使用可逆的Transformer在每一层中就无需存储激活值,也就避免了nln_l这一项。可逆层代替标准的残差层,可以在训练过程中只存储一次激活,而不是N次。

Chunking

上述消除了nln_l项的影响,深层的网络仍然占有大量内存。在FFN中中间隐藏层的纬度通常非常大,比如dff=4kd_{ff}=4k或者更大。由于FFN的计算与序列中的位置完全无关,因此计算可以被分割成cc个块,以降低内存的使用。虽然该操作其实可并行处理,但是每次只计算一个chunk,通过时间换取内存空间。
Y2=[Y2(1);;Y2(c)]=[X2(1)+ FeedForward (Y1(1));;X2(c)+ FeedForward (Y1(c))]Y_{2}=\left[Y_{2}^{(1)} ; \ldots ; Y_{2}^{(c)}\right]=\left[X_{2}^{(1)}+\text { FeedForward }\left(Y_{1}^{(1)}\right) ; \ldots ; X_{2}^{(c)}+\text { FeedForward }\left(Y_{1}^{(c)}\right)\right]
另外,可逆操作和反向传播操作也分块处理。除FFN之外,对于词汇量大的模型(单词类型>dmodel>d_{model}),还对输出处的log- probability分块,并一次计算序列各部分的损失。

4. 实验结果

对图像生成任务imagenet64(序列长度为12K)和文本任务enwik8-64K(即序列长度为64K)进行了实验,评价了可逆层、共享query-key、LSH attention对内存、精度和速度的影响。

可逆层和共享query-key的影响:
在这里插入图片描述

Figure 3中的左部分验证共享query-key的影响。从perplexity曲线结果可以看出,共享QK attention并不会明显逊色于常规attention。且在enwik8数据集中收敛更快。换句话说,使用共享QK attention并不会牺牲准确性。
Figure 3中的右部分验证的是可逆层的影响。实验中对比的可逆层和常规Transformer参数量相同,且学习曲线看起来也几乎相同。这些结果表明,可逆Transformer在节省内存的同时并不会牺牲精度。

LSH attention的影响:
如Figure 4所示,可以看出随着hash数的增多精度也提升了。
在这里插入图片描述

更大的Reformer模型:
Figure 5展示了不同层数的Reformer在envik8和imagenet64上的表现。下图(左)是Big Reformer随层数变化指标结果,20层依然无压力。而下图(右)是普通attention和LSH attention在不同序列长度的速度比较,当序列很长的时候,LSH具有显著的优势。
在这里插入图片描述

5. 总结

Reformer将Transformer的建模能力与能够在长序列上高效执行的体系结构相结合,使其即使处理大模型时,也可以使用较小的内存。这将有助于大型、海量参数化的Transformer模型变得更广泛可用。此外,处理长序列的能力为Reformer在许多生成任务上的使用开辟了道路。除了生成非常长的连贯的文本外
Reformer可以把Transformer模型的能力带到其他领域,如时间序列预测、音乐、图像等。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章