在上一篇【機器學習】從RNN到Attention 中篇 從Seq2Seq到Attention in Seq2Seq中我們介紹了基於RNN結構的Attention機制,Attention機制通過encoder和注意力權重可以觀察到全局信息,從而較好地解決了長期依賴的問題,但是RNN的結構本身的輸入依賴於前一時刻模型的輸出,因此無法並行化。既然Attention機制本身就具有捕捉全局信息的能力,那麼我們是否可以拋開RNN結構,只使用Attention機制,從而既能捕捉全局信息,又能並行化呢?Transformer模型就使用了這樣一種思路。
從一個例子看Self-Attention
Self-Attention的核心在於學習序列中其他部分對於該部分的權重值,比如
The animal didn't cross the street because it was too tired
其中的“it”代指的是 the street還是The animal呢?self-attention的神奇之處在於可以讓模型更關注於The animal,以便更好地解讀句子的含義,如下圖所示
模型的整體結構
Transformer本質上是一個encoder-decoder的結構,如下圖所示:
如上圖,圖左的encoder部分由6個相同的子encoder組成,圖右的encoder部分也是由6個相同的子decoder組成。
其中的6個子encoder包含self-attention和FFN(前饋神經網絡)兩部分,6個子decoder包含self-attention,Encoder-Decoder Attention和FFN三部分。如下圖所示
Self-Attention結構
我們先來看self-attention部分
self-attention主要由三個矩陣Q,K,V構成,Q(Query), K(Key), V(Value)三個矩陣均來自於輸入X
圖中的爲模型參數,可以通過優化算法學習得到,輸入X與進行矩陣乘法後得到矩陣Q(Query), K(Key), V(Value),三者經過Attention操作後得到注意力矩陣,公式爲
有點複雜,我們來看一個例子,假如要翻譯一個詞組Thinking Machines,其中用 表示Thinking的輸入的embedding vector,用表示Machines的embedding vector。
我們來計算Thinking這個詞與其他詞的Attention Score,Attention Score的物理含義就是將當前的詞(Thinking)作爲搜索的query,來表示和句子中包含自身在內的所有詞(key)的相關性,如上圖中的q1k1和q1k2分別表示Thinking與Thinking自身的相關性以及Thinking與Machines的相關性,當前單詞與其自身的attention score最大,其他單詞與當前單詞相關性通過attention score。然後我們在用這些attention score與value vector相乘,得到加權的向量。
的作用在於放縮,從上例看112和96如果不做放縮直接進行softmax,則Thinking與Machines的Attention Score趨近於0,放縮後則略平衡一些。之所以選擇是基於假設:如果q和k的每一維滿足均值0,方差1的隨機變量,則它們的點乘滿足均值爲0,方差爲,除以它相當於對數據做標準化。
源碼錶示如下:
def attention(query, key, value, mask=None, dropout=None):
"Compute 'Scaled Dot Product Attention'"
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) \
/ math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim = -1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
注意這裏其實是可以帶dropout層的,而masked_fill的作用在於讓我們在decoder中,只能關注到到當前單詞之前的、已經翻譯過的輸出的單詞,當前單詞之後的單詞則不被關注到。
FFN(前向傳播網絡)結構
self-attention的輸出後接入的是一個FFN(前向傳播網絡)結構,如下圖所示
先經過一個relu然後再過一個線性加權,可以看到無論是self-attention還是FFN都不再依賴於前一時刻的輸入,因此transformer的整個計算過程是可以並行的。
源碼錶示如下
class PositionwiseFeedForward(nn.Module):
"Implements FFN equation."
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
同樣,這裏可以帶dropout操作
Multihead attention結構
所謂Multihead attention結構就是同一輸入對應的self-attention層具有多個Q,K,V的組合,如下圖所示
Multihead attention結構的主要作用在於:提高了模型的表達能力,一組Q,K,V可以只捕捉句中單詞的一組相關關係,也就是“表達子空間”,多組Q,K,V則可以捕捉句中單詞的多組相關關係。在一層self-attention中共有八組Q,K,V,經過計算得到的,顯然,這個計算過程也是可以並行的。
那麼問題來了,下一層FFN並不能直接處理8個矩陣輸入,而是需要一個矩陣,這8組z該如何處理呢?答案是將它們鏈接(concat)起來後,再送入FFN,如下圖所示
根據上面“The animal didn’t cross the street because it was too tired”的例子,此時的Multihead結構可以捕捉到單詞間的如下關係
位置編碼
從上面的介紹我們發現transformer可以取代RNN結構並且可以進行並行計算,但卻丟失了句子之前的順序關係,爲了解決這個問題,引入了位置編碼(Position Encoding)。如下圖所示
假設位置變量有4維,則位置編碼的過程爲:
具體的計算公式爲
設計的思想是考慮的單詞的的絕對位置和相對位置,大意和和這兩個公式有關,具體爲什麼筆者沒有深究,大家有興趣可以瞭解下。
殘差結構與層歸一化
一個完整的encoder結構還包含兩個殘差結構(residual)和一層層歸一化(layer norm),代碼表示如下
class LayerNorm(nn.Module):
"Construct a layernorm module (See citation for details)."
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
class SublayerConnection(nn.Module):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SublayerConnection, self).__init__()
self.norm = LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
"Apply residual connection to any sublayer with the same size."
return x + self.dropout(sublayer(self.norm(x)))
可以發現其中的殘差結構 x + self.dropout(sublayer(self.norm(x))),如下圖所示
這種層結構在decoder中同樣存在,以一個兩層的encoder和decoder爲例,如圖所示
我們可以發現相比較於encoder,decoder中還包含了一個encoder-decoder attention結構。
decoder結構
從上一張圖我們可以看到,相比較於encoder層已有的Self-Attention模塊、前饋網絡(FFN)模塊、殘差結構和歸一化部分,decoder層多了一個Encoder-Decoder Attention模塊,那麼這個模塊是怎麼構成的呢?
我們先來看一個整個模型的輸入輸出過程
我們從圖中可以看到,encoder的輸出在decoder中是作爲K,V而存在的,這就是decoder中encoder-decoder層的來源:decoder前一層的輸入作爲當前層的Query,而Key和Value則來源於encoder層的輸出。代碼表示如下:
class DecoderLayer(nn.Module):
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3)
def forward(self, x, memory, src_mask, tgt_mask):
"Follow Figure 1 (right) for connections."
m = memory
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
return self.sublayer[2](x, self.feed_forward)
注意代碼中的memory就是指encoder層的輸出,在decoder層中self.src_attn的參數分別爲Q,K,V,mask,mask的作用我們在self-attention結構中有提及過,它的存在是爲了使網絡只能獲取到當前時刻之前的輸入,即只對當前時刻 t 之前的時刻輸入進行attention計算。
我們從代碼中可以看出:
- 在decoder層的傳播過程中,輸入x先經過self-attention操作得到一箇中間變量x
- encoder-decoder層同樣是一個self-attention結構,以Q,K,V爲輸入
- 中間變量x作爲encoder-decoder層中的Query變量,而Key和Value則來源於encoder層的輸出memory,最後經過FFN計算得到decoder層的輸出。
總結
- Transormer結構是以self-attention結構作爲基礎,self-attention結構主要由Q,K,V三個輸入,Self-Attention的核心在於學習序列中其他部分對於該部分的權重值
- encoder最頂層的輸入是embedding得到的向量X,X通過三個得到Q,K,V,之後經過Attention操作後得到注意力矩陣
,的作用在於放縮 - Multihead attention結構的主要作用在於:提高了模型的表達能力,一組Q,K,V可以只捕捉句中單詞的一組相關關係,也就是“表達子空間”,多組Q,K,V則可以捕捉句中單詞的多組相關關係,並且這種計算是可以並行加速的。
- 殘差結構、FFN層和layer-Norm稍微關注一下~
- 基於上述結構的self-attention表達能力強、可以並行,但是丟失了位置信息,位置編碼部分彌補了這一缺點,但這種方式是否可以改進值得商榷
- decoder相比於encoder層,多了一個 encoder-decoder,它同樣是一個self-attention結構,以Q,K,V爲輸入,Key和Value來源於encoder層的輸出,而Query則來源於self-attention層的輸出。
最後大家可以欣賞一下transformer在翻譯任務的整個操作過程。
參考資料:
模型部分:http://jalammar.github.io/illustrated-transformer/
代碼部分:http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks