Transformer原理及代碼註釋(Attention is all you need)

Transformer是谷歌針對NLP的機器翻譯問題,2017年發表了一篇名爲Attention Is All You Need 的論文中提出的模型。Transformer採用了機器翻譯中通用的encoder-decoder模型,但摒棄了以往模塊內部的RNN模型,只是完全依賴注意力機制來構建模型。其優點有以下幾點:

  • 結構簡單,拋棄RNN模型的優點在於沒有了時序的限制,RNN限制了數據必須按照輸入的順序處理前後有依賴性,所以在面對數據量大的時候,耗時會很長。但Transformer的self-attention機制使得其可以進行並行計算來加速
  • 每個單詞會考慮句子中所有詞對其的影響,一定程度上改善了RNN中由於句子過長帶來的誤差,Transformer的翻譯結果要比RNN好很多

下面會從原理和代碼來解讀Transformer模型:

1 Transformer 原理

首先按慣例上模型圖(
在這裏插入圖片描述
顯然其可以分成左右兩部分,爲了方便理解,我們把左邊叫做Encoders,右邊叫做Decoders。上圖只是模型的示意圖,實際上這兩個部分分別由六個圖示這樣的基本結構堆疊起來,像這樣:
在這裏插入圖片描述
爲了更好的理解,我們按照數據輸入之後在模型中的行走路線解釋模型的原理。

0 位置編碼

由於機器翻譯需要考慮詞序之間的關係,而且attention機制並沒有考慮詞序關係,所以我們要提前爲單詞加上位置編碼,使得模型可以利用輸入序列的順序信息。位置編碼的編碼規則如下所示:
在這裏插入圖片描述
如果我們emdedding的維度爲4,那麼示例可以像下圖(
在這裏插入圖片描述

1.1 Encoder

在這裏插入圖片描述
Encoder的作用是將輸入經過注意力機制和前饋神經網絡轉變成編碼,後期作爲輸入傳入Decoder解碼成另一種語言。輸入的字符串已經預先變成了詞嵌入矩陣形式(論文中使用的詞向量維數是512維),詞嵌入矩陣被輸入最底層的Encoder,然後將其拆分成向量輸入attention層進行計算,Attention層會輸出同樣是512維的向量列表,這兩個矩陣經過多頭Attention機制的整合,再進入前饋神經網絡,前饋神經網絡也輸出一個爲512維度的列表,然後將輸出傳到下一個Encoder。注意,每個Encoder模塊的前饋神經網絡都是獨立且結構相同的。(給並行創造條件)

1.1.1 Transformer的Attention機制

首先我們先來看Attention部分,模型的attention其實由兩部分組成:
在這裏插入圖片描述

1.1.1.1 Scaled Dot-Product Attention

首先是樸素的一看就不是並行的部分:D
在這裏插入圖片描述
計算self-attention首先從計算三個向量開始,對於每一個單詞,我們都需要三個向量:Query, Key, Value。這些向量是通過當前單詞與分別的訓練矩陣相乘得到的,維度自擬(這裏是64維)。另外,訓練矩陣在這裏假設是已經訓練好給定的,具體來源我們下一節再解釋。
然後有了材料我們就可以套公式了(霧):
在這裏插入圖片描述
首先我們用Q,K相乘得到的結果來相應單詞的得分,舉例如上上圖,然後將得分除以8,也就是sqrt(dk)sqrt(d_k),使得訓練過程中具有更穩定的梯度(論文中說:對於dkd_k很大的時候,點積得到的結果維度很大,使得結果處於softmax函數梯度很小的區域,這造成梯度很小,對反向傳播不利。爲了克服這個負面影響,除以一個縮放因子,可以一定程度上減緩這種情況???)。接下來再將輸出乘V過softmax,得到權值的向量,然後將其累加到詞向量中,產生此Attention層的輸出。通俗來講,公式大意是通過確定Q和K之間的相似程度來選擇V

1.1.1.2 Multi-Head Attention

通過論文的圖示,你一定看到了恍若虛影的東西,對,這就是玄學 可以並行計算的部分了。
在這裏插入圖片描述
公式如下:
在這裏插入圖片描述
多頭Attention提供並訓練了多個Q,K,V的訓練矩陣,他們用於將詞嵌入投影到不同的表示子空間(representation subspaces)中。通過此Attention層,我們爲每一個header都獨立維護了一套QKV訓練矩陣,在經過上一節的attention層處理之後,因爲我們有多個並行的attention,所以肯定會得到多個不同的Z矩陣,然後我們通過concat函數(將這幾個矩陣簡單相拼接)組合成一個大矩陣,之後與WOW^O相乘,過線性模型得到的結果就可以進入前饋神經網絡了。

下面是Attention過程的總結:
在這裏插入圖片描述

1.1.2 前饋神經網絡

這是一個Position-wise的前饋神經網絡,激活函數的順序是線性模型-RELU-線性模型:
在這裏插入圖片描述

1.1.3 layer-normalization

在這裏插入圖片描述
可以看到,詞向量除了喂入attention模型之外,還另外在喂入前饋神經網絡中與Z進行了整合。

1.2 Decoder

在這裏插入圖片描述
Decoder的結構與Encoder其實是非常像的,只是多了一層E-D Attention機制,爲了讓decoder捕獲輸入序列的位置信息。但是與Encoder不同,Decoder的每一次輸出都作爲下一次的時序的輸入,進入最底層的decoder:
在這裏插入圖片描述
另外,decoder的attention機制是按照輸出序列中出現比較早的位置來排序的,與亂序的encoder不同。

1.3 輸出

Decoder的輸出是一個浮點數的向量列表,我們需要再將其通過線性層和softmax纔可以將其變成輸出的單詞:
在這裏插入圖片描述

2 代碼註釋

'''
 code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612
 Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
             https://github.com/JayParks/transformer
'''
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt

dtype = torch.FloatTensor
# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']

# Transformer Parameters
# Padding Should be Zero index
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'S' : 5, 'E' : 6}
number_dict = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

src_len = 5
tgt_len = 5

d_model = 512  # Embedding Size
d_ff = 2048 # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6  # number of Encoder of Decoder Layer
n_heads = 8  # number of heads in Multi-Head Attention

def make_batch(sentences):
   input_batch = [[src_vocab[n] for n in sentences[0].split()]]
   output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]
   target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]
   return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch))

def get_sinusoid_encoding_table(n_position, d_model):
   def cal_angle(position, hid_idx):
       return position / np.power(10000, 2 * (hid_idx // 2) / d_model)
   def get_posi_angle_vec(position):
       return [cal_angle(position, hid_j) for hid_j in range(d_model)]

   sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
   sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
   sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
   return torch.FloatTensor(sinusoid_table)

def get_attn_pad_mask(seq_q, seq_k):
   # print(seq_q)
   batch_size, len_q = seq_q.size()
   batch_size, len_k = seq_k.size()
   # eq(zero) is PAD token
   pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking
   return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

def get_attn_subsequent_mask(seq):
   attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
   subsequent_mask = np.triu(np.ones(attn_shape), k=1)
   subsequent_mask = torch.from_numpy(subsequent_mask).byte()
   return subsequent_mask


##Encoder attention-1
class ScaledDotProductAttention(nn.Module):
   def __init__(self):
       super(ScaledDotProductAttention, self).__init__()

   def forward(self, Q, K, V, attn_mask):
       scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
       # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
       scores.masked_fill_(attn_mask, -1e9)
       # Fills elements of self tensor with value where mask is one.
       attn = nn.Softmax(dim=-1)(scores)
       context = torch.matmul(attn, V)
       return context, attn


##Encoder attention-2
class MultiHeadAttention(nn.Module):
   def __init__(self):
       super(MultiHeadAttention, self).__init__()
       self.W_Q = nn.Linear(d_model, d_k * n_heads)
       self.W_K = nn.Linear(d_model, d_k * n_heads)
       self.W_V = nn.Linear(d_model, d_v * n_heads)
   def forward(self, Q, K, V, attn_mask):
       # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
       residual, batch_size = Q, Q.size(0)
       # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
       q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
       k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
       v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

       attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

       # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
       context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
       context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]
       output = nn.Linear(n_heads * d_v, d_model)(context)
       return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]


##前饋神經網絡 Position-wise版
class PoswiseFeedForwardNet(nn.Module):
   def __init__(self):
       super(PoswiseFeedForwardNet, self).__init__()
       self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
       self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)

   def forward(self, inputs):
       residual = inputs # inputs : [batch_size, len_q, d_model]
       output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))
       output = self.conv2(output).transpose(1, 2)
       return nn.LayerNorm(d_model)(output + residual)

#Encoder 基本模塊
class EncoderLayer(nn.Module):
   def __init__(self):
       super(EncoderLayer, self).__init__()
       self.enc_self_attn = MultiHeadAttention()
       self.pos_ffn = PoswiseFeedForwardNet()

   def forward(self, enc_inputs, enc_self_attn_mask):
       enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
       enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
       return enc_outputs, attn

#Decoder 基本模塊
class DecoderLayer(nn.Module):
   def __init__(self):
       super(DecoderLayer, self).__init__()
       self.dec_self_attn = MultiHeadAttention()
       self.dec_enc_attn = MultiHeadAttention()
       self.pos_ffn = PoswiseFeedForwardNet()

   def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
       dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
       dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
       dec_outputs = self.pos_ffn(dec_outputs)
       return dec_outputs, dec_self_attn, dec_enc_attn

class Encoder(nn.Module):
   def __init__(self):
       super(Encoder, self).__init__()
       self.src_emb = nn.Embedding(src_vocab_size, d_model)
       self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_vocab_size, d_model),freeze=True)
       self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

   def forward(self, enc_inputs): # enc_inputs : [batch_size x source_len]
       enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,0]]))
       enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
       enc_self_attns = []
       for layer in self.layers:
           enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
           enc_self_attns.append(enc_self_attn)
       return enc_outputs, enc_self_attns

class Decoder(nn.Module):
   def __init__(self):
       super(Decoder, self).__init__()
       self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
       self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_vocab_size, d_model),freeze=True)
       self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

   def forward(self, dec_inputs, enc_inputs, enc_outputs): # dec_inputs : [batch_size x target_len]
       dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[5,1,2,3,4]]))
       dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
       dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
       dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)

       dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

       dec_self_attns, dec_enc_attns = [], []
       for layer in self.layers:
           dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
           dec_self_attns.append(dec_self_attn)
           dec_enc_attns.append(dec_enc_attn)
       return dec_outputs, dec_self_attns, dec_enc_attns


##主模型
class Transformer(nn.Module):
   def __init__(self):
       super(Transformer, self).__init__()
       self.encoder = Encoder()
       self.decoder = Decoder()
       self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
   def forward(self, enc_inputs, dec_inputs):
       enc_outputs, enc_self_attns = self.encoder(enc_inputs)
       dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
       dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x src_vocab_size x tgt_vocab_size]
       return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

## 貪婪算法 模型損失函數和翻譯矩陣的訓練
###https://blog.csdn.net/qq_41664845/article/details/84969266
def greedy_decoder(model, enc_input, start_symbol):
   """
   For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
   target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
   Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
   :param model: Transformer Model
   :param enc_input: The encoder input
   :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
   :return: The target input
   """
   enc_outputs, enc_self_attns = model.encoder(enc_input)
   dec_input = torch.zeros(1, 5).type_as(enc_input.data)
   next_symbol = start_symbol
   for i in range(0, 5):
       dec_input[0][i] = next_symbol
       dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
       projected = model.projection(dec_outputs)
       prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
       next_word = prob.data[i]
       next_symbol = next_word.item()
   return dec_input

def showgraph(attn):
   attn = attn[-1].squeeze(0)[0]
   attn = attn.squeeze(0).data.numpy()
   fig = plt.figure(figsize=(n_heads, n_heads)) # [n_heads, n_heads]
   ax = fig.add_subplot(1, 1, 1)
   ax.matshow(attn, cmap='viridis')
   ax.set_xticklabels(['']+sentences[0].split(), fontdict={'fontsize': 14}, rotation=90)
   ax.set_yticklabels(['']+sentences[2].split(), fontdict={'fontsize': 14})
   plt.show()

model = Transformer()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(20):
   optimizer.zero_grad()
   enc_inputs, dec_inputs, target_batch = make_batch(sentences)
   outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
   loss = criterion(outputs, target_batch.contiguous().view(-1))
   print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
   loss.backward()
   optimizer.step()

# Test
greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=tgt_vocab["S"])
predict, _, _, _ = model(enc_inputs, greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])

print('first head of last state enc_self_attns')
showgraph(enc_self_attns)

print('first head of last state dec_self_attns')
showgraph(dec_self_attns)

print('first head of last state dec_enc_attns')
showgraph(dec_enc_attns)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章