本篇整理 Transformer 架構,及在 Transformer 基礎上衍生出來的 BERT 模型,最後給出
相應的應用案例。
1.Transformer的架構
Transformer 網絡架構架構由 Ashish Vaswani 等人在 Attention Is All You Need 一文中提出,並用於機器翻譯任務,和以往網絡架構有所區別的是,該網絡架構中,編碼器和解碼器沒有采用 RNN 或 CNN 等網絡架構,而是採用完全依賴於注意力機制的架構。網絡架構如下所示:
該網絡架構中引入了多頭注意力機制,該機制的網絡架構如下所示:
這裏有必要對多頭注意力機制進行一定的解釋。假設輸入數據的 batch size 爲B B B ,輸入數據的最大長度爲F F F ,輸出數據的最大長度爲T T T ,共有N N N 個注意力頭,每個注意力頭的輸出維度爲H H H ,則輸入/輸出數據中每個詞的 Embedding 的維度爲 E = N × H E=N×H E = N × H ,且注意力頭中每個頭對應的W Q , W K , W V \boldsymbol{W}^{Q}, \boldsymbol{W}^{K} ,\boldsymbol{W}^{V} W Q , W K , W V 矩陣均屬於R E × H \mathbb{R}^{E \times H} R E × H 。考慮到編碼器和解碼器涉及三個注意過程,且輸入有所不同,這裏分別來看。
2.編碼器自注意力
考慮輸入數據爲X ∈ R B × F × E \mathbf{X} \in \mathbb{R}^{B \times F \times E} X ∈ R B × F × E ,對輸入數據應用如下線性變換:
Q = X W Q , ( W Q ∈ R E × H ⇒ Q ∈ R B × F × H ) K = X W K , ( W K ∈ R E × H ⇒ K ∈ R B × F × H ) V = X W V , ( W V ∈ R E × H ⇒ V ∈ R B × F × H )
\begin{aligned}
&\mathbf{Q}=\mathbf{X} \mathbf{W}^{Q}, \quad\left(\mathbf{W}^{Q} \in \mathbb{R}^{E \times H} \Rightarrow \mathbf{Q} \in \mathbb{R}^{B \times F \times H}\right)\\
&\mathbf{K}=\mathbf{X} \mathbf{W}^{K}, \quad\left(\mathbf{W}^{K} \in \mathbb{R}^{E \times H} \Rightarrow \mathbf{K} \in \mathbb{R}^{B \times F \times H}\right)\\
&\mathbf{V}=\mathbf{X} \boldsymbol{W}^{V}, \quad\left(\boldsymbol{W}^{V} \in \mathbb{R}^{E \times H} \Rightarrow \mathbf{V} \in \mathbb{R}^{B \times F \times H}\right)
\end{aligned}
Q = X W Q , ( W Q ∈ R E × H ⇒ Q ∈ R B × F × H ) K = X W K , ( W K ∈ R E × H ⇒ K ∈ R B × F × H ) V = X W V , ( W V ∈ R E × H ⇒ V ∈ R B × F × H )
在上述變換基礎上進行如下計算,得到輸入中每個詞和自身及其他詞之間的關係權重
S = softmax ( Q K ⊤ H )
\mathbf{S}=\operatorname{softmax}\left(\frac{\mathbf{Q K}^{\top}}{\sqrt{H}}\right)
S = s o f t m a x ( H Q K ⊤ )
上述變換 K T K^T K T 表示對張量的最內部矩陣進行轉置,因此K ⊤ ∈ R B × H × F \mathbf{K}^{\top} \in \mathbb{R}^{B \times H \times F} K ⊤ ∈ R B × H × F ,Q K ⊤ \mathrm{QK}^{\top} Q K ⊤ 表示相同維度下張量 Q 和張量 K T K^T K T 最內部矩陣執行矩陣乘法運算 (即 numpy.matmul 運算),因此有S ∈ R B × F × F \mathbf{S} \in \mathbb{R}^{B \times F \times F} S ∈ R B × F × F ,該張量表示輸入數據中每個詞和自身及其他詞的關係權重,每一行的得分之和爲 1,即
∀ i , j np.sum ( S [ i , j , : ] ) = 1
\forall i, j \quad \operatorname{np.sum}(\mathbf{S}[i, j,:])=1
∀ i , j n p . s u m ( S [ i , j , : ] ) = 1
基於該得分即可得到,每個詞在當前上下文下的新的向量表示,公式如下:
x h = S V ⇒ X h ∈ R B × F × H
\mathbf{x}^{h}=\mathbf{S V} \quad \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times F \times H}
x h = S V ⇒ X h ∈ R B × F × H
考慮到 Transformer 採用了 N 個注意力頭,因此最終產生了集合大小爲 N 的注意力集合{ X h 1 , … , X h N } \left\{\mathbf{X}^{h_{1}}, \ldots, \mathbf{X}^{h_{N}}\right\} { X h 1 , … , X h N } ,將該集合中中的所有張量按照最後一個維度進行拼接,並採用矩陣W O ∈ R E × E \boldsymbol{W}^{O} \in \mathbb{R}^{E \times E} W O ∈ R E × E 進行變換,得到最終生成的自注意力輸入數據,公式如下:
X a = numpy.concatenate ( ( X h 1 , … , X h N ) , axis = − 1 ) W O
\mathbf{X}^{a}=\text { numpy.concatenate }\left(\left(\mathbf{X}^{\mathrm{h}_{1}}, \ldots, \mathbf{X}^{\mathrm{h}_{\mathrm{N}}}\right), \text { axis }=-1\right) \boldsymbol{W}^{O}
X a = numpy.concatenate ( ( X h 1 , … , X h N ) , axis = − 1 ) W O
因此有X a ∈ R B × F × E \mathbf{X}^{a} \in \mathbb{R}^{B \times F \times E} X a ∈ R B × F × E 。
考慮到多頭注意力可以並行運算,爲了充分發揮向量化計算並行效率,實際實現中往往採用如下表示方案:
X par = reshape ( X , to shape = [ B × F , N × H ] ) Q p a r = X p a r W Q p a r ( W Q p a r ∈ R ( N × H ) × ( N × H ) ⇒ Q p a r ∈ R ( B × F ) × ( N × H ) ) K p a r = X p a r W K p a r ( W K p a r ∈ R ( N × H ) × ( N × H ) ⇒ K p a r ∈ R ( B × F ) × ( N × H ) ) V p a r = X p a r W V p a r ( W V p a r ∈ R ( N × H ) × ( N × H ) ⇒ V p a r ∈ R ( B × F ) × ( N × H ) )
\begin{aligned}
&\mathbf{X}^{\text {par }}=\text { reshape }(\mathbf{X}, \text { to shape }=[B \times F, N \times H])\\
&\begin{array}{ll}
{\mathbf{Q}^{p a r}=\mathbf{X}^{p a r} \boldsymbol{W}^{Q^{p a r}}} & {\left(\boldsymbol{W}^{Q^{p a r}} \in \mathbb{R}^{(N \times H) \times(N \times H)} \Rightarrow \mathbf{Q}^{p a r} \in \mathbb{R}^{(B \times F) \times(N \times H)}\right)} \\
{\mathbf{K}^{p a r}=\mathbf{X}^{p a r} \boldsymbol{W}^{K^{p a r}}} & {\left(\boldsymbol{W}^{K^{p a r}} \in \mathbb{R}^{(N \times H) \times(N \times H)} \Rightarrow \mathbf{K}^{p a r} \in \mathbb{R}^{(B \times F) \times(N \times H)}\right)} \\
{\text { V }^{p a r}=\mathbf{X}^{p a r} \boldsymbol{W}^{V^{p a r}}} & {\left(\boldsymbol{W}^{V^{p a r}} \in \mathbb{R}^{(N \times H) \times(N \times H)} \Rightarrow \mathbf{V}^{p a r} \in \mathbb{R}^{(B \times F) \times(N \times H)}\right)}
\end{array}
\end{aligned}
X par = reshape ( X , to shape = [ B × F , N × H ] ) Q p a r = X p a r W Q p a r K p a r = X p a r W K p a r V p a r = X p a r W V p a r ( W Q p a r ∈ R ( N × H ) × ( N × H ) ⇒ Q p a r ∈ R ( B × F ) × ( N × H ) ) ( W K p a r ∈ R ( N × H ) × ( N × H ) ⇒ K p a r ∈ R ( B × F ) × ( N × H ) ) ( W V p a r ∈ R ( N × H ) × ( N × H ) ⇒ V p a r ∈ R ( B × F ) × ( N × H ) )
在上述並行計算基礎上通過如下計算得到詞和自身及其他詞的關係權值:
v p a r = numpy.reshape (Vpar, ( B , F , N , H ) ) v’art = numpy. transpose (V p a r , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ V p a r t ∈ R B × N × F × H X h = S p a r v p a r t ⇒ X h ∈ R B × N × F × H X h = numpy. transpose ( X h , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ X h ∈ R B × F × N × H X h = numpy.reshape ( X h , F , N × H ) ) ⇒ X h ∈ R B × F × E
\begin{array}{l}
{\left.\mathbf{v}^{p a r}=\text { numpy.reshape (Vpar, }(B, F, N, H)\right)} \\
{ \text { v'art }\left.=\text { numpy. transpose (V }^{p a r}, \text { axes }=[0,2,1,3]\right) \Rightarrow \mathbf{V}^{p a r^{t}} \in \mathbb{R}^{B \times N \times F \times H}} \\
{\mathbf{X}^{h}=\mathbf{S}^{p a r} \mathbf{v}^{p a r^{t}} \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times N \times F \times H}} \\
{\mathbf{X}^{h}=\text { numpy. transpose }\left(\mathbf{X}^{h}, \text { axes }=[0,2,1,3]\right) \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times F \times N \times H}} \\
{\left.\mathbf{X}^{h}=\text { numpy.reshape }\left(\mathbf{X}^{h}, F, N \times H\right)\right) \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times F \times E}}
\end{array}
v p a r = numpy.reshape (Vpar, ( B , F , N , H ) ) v’art = numpy. transpose (V p a r , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ V p a r t ∈ R B × N × F × H X h = S p a r v p a r t ⇒ X h ∈ R B × N × F × H X h = numpy. transpose ( X h , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ X h ∈ R B × F × N × H X h = numpy.reshape ( X h , F , N × H ) ) ⇒ X h ∈ R B × F × E
3.解碼器自注意力
解碼器的自注意力和編碼器的自注意力基本完全一致,需要注意的是解碼過程是one by one的生成過程,因此輸出數據中的每個詞在進行自注意力的過程時,僅可以看到當前輸出位置的所有前驅詞的信息,因此需要對輸出數據中的詞進行掩碼操作,該操作即對應上面的左圖上的掩碼操作。該掩碼操作相當於執行如下操作:
A = Q K ⊤ + M S p a r = softmax ( A H ) ⇒ S p a r ∈ R B × N × T × T
\begin{aligned}
\mathbf{A} &=\mathbf{Q K}^{\top}+\mathbf{M} \\
\mathbf{S}^{p a r} &=\operatorname{softmax}\left(\frac{\mathbf{A}}{\sqrt{H}}\right) \Rightarrow \mathbf{S}^{p a r} \in \mathbb{R}^{B \times N \times T \times T}
\end{aligned}
A S p a r = Q K ⊤ + M = s o f t m a x ( H A ) ⇒ S p a r ∈ R B × N × T × T
其中M ∈ R 1 × 1 × T × T \mathbf{M} \in \mathbb{R}^{1 \times 1 \times T \times T} M ∈ R 1 × 1 × T × T 爲掩碼,其最內部矩陣爲方陣,該方陣主對角線及以下元素均爲 0,主對角線以上元素爲− ∞ -\infty − ∞ 。譬如 T = 5 時,最內部方陣內容如下:
M = [ 0 − ∞ − ∞ − ∞ − ∞ 0 0 − ∞ − ∞ − ∞ 0 0 0 − ∞ − ∞ 0 0 0 0 − ∞ 0 0 0 0 0 ]
\boldsymbol{M}=\left[\begin{array}{ccccc}
{0} & {-\infty} & {-\infty} & {-\infty} & {-\infty} \\
{0} & {0} & {-\infty} & {-\infty} & {-\infty} \\
{0} & {0} & {0} & {-\infty} & {-\infty} \\
{0} & {0} & {0} & {0} & {-\infty} \\
{0} & {0} & {0} & {0} & {0}
\end{array}\right]
M = ⎣ ⎢ ⎢ ⎢ ⎢ ⎡ 0 0 0 0 0 − ∞ 0 0 0 0 − ∞ − ∞ 0 0 0 − ∞ − ∞ − ∞ 0 0 − ∞ − ∞ − ∞ − ∞ 0 ⎦ ⎥ ⎥ ⎥ ⎥ ⎤
其餘操作和編碼器自注意力機制一致,唯一不同的是此時需要向上面那樣將輸入數據換成 Y ∈ R N × T × E \mathbf{Y} \in \mathbb{R}^{N \times T \times E} Y ∈ R N × T × E ,因此所有的 F F F 均需換成 T T T 。
4.編碼解碼注意力
編碼解碼注意力和自注意力類似,唯一不同的是計算 Q, K, V 使用的數據有所區別,計算
Q 時採用 Y,計算 K 和 V 時採用 X,因此有:
X p a r = \mathbf{X}^{\mathrm{par}}= X p a r = reshape ( X , to shape = [ B × F , N × H ] ) (\mathbf{X}, \text { to shape }=[B \times F, N \times H]) ( X , to shape = [ B × F , N × H ] )
Y par = \mathbf{Y}^{\text {par }}= Y par = reshape ( Y , to shape = [ B × T , N × H ] ) (\mathbf{Y}, \text { to shape }=[B \times T, N \times H]) ( Y , to shape = [ B × T , N × H ] )
Q ende-par = Y par W Q ende-par ( W Q ende-par ∈ R ( N × H ) × ( N × H ) ) \mathbf{Q}^{\text {ende-par }} =\mathbf{Y}^{\text {par }} \boldsymbol{W}^{Q \text { ende-par }}\left(\boldsymbol{W}^{Q \text { ende-par }} \in \mathbb{R}^{(N \times H) \times(N \times H)}\right) Q ende-par = Y par W Q ende-par ( W Q ende-par ∈ R ( N × H ) × ( N × H ) )
K ende-par = X par W K coldeper ( W K ende-par ∈ R ( N × H ) × ( N × H ) ) \mathbf{K}^{\text {ende-par }} = \mathbf{X}^{\text {par }} \boldsymbol{W}^{K^{\text {coldeper }}} \quad\left(\boldsymbol{W}^{K \text { ende-par }} \in \mathbb{R}^{(N \times H) \times(N \times H)}\right) K ende-par = X par W K coldeper ( W K ende-par ∈ R ( N × H ) × ( N × H ) )
V ende-par = X par W Vpar ( W Vende-par ∈ R ( N × H ) × ( N × H ) ) \mathbf{V}^{\text {ende-par }}= \mathbf{X}^{\text {par }} \boldsymbol{W}^{\text {Vpar }} \quad\left(\boldsymbol{W}^{\text {Vende-par }} \in \mathbb{R}^{(N \times H) \times(N \times H)}\right) V ende-par = X par W Vpar ( W Vende-par ∈ R ( N × H ) × ( N × H ) )
因此有:
S ende-par = softmax ( Q ende-par t K ende-par t ⊤ H ) ⇒ S ende-par ∈ R B × N × T × F
\mathbf{S}^{\text {ende-par}}=\operatorname{softmax}\left(\frac{\mathbf{Q}^{\text {ende-par}^{t}} \mathbf{K}^{\text {ende-par}^{t \top}}}{\sqrt{H}}\right) \Rightarrow \mathbf{S}^{\text {ende-par}} \in \mathbb{R}^{B \times N \times T \times F}
S ende-par = s o f t m a x ( H Q ende-par t K ende-par t ⊤ ) ⇒ S ende-par ∈ R B × N × T × F
y ende-h = S ende-par V ende-par ⇒ Y ende-h ∈ R B × N × T × H
\mathbf{y}^{\text {ende-h}}=\mathbf{S}^{\text {ende-par}} \mathbf{V}^{\text {ende-par}} \Rightarrow \mathbf{Y}^{\text {ende-h}} \in \mathbb{R}^{B \times N \times T \times H}
y ende-h = S ende-par V ende-par ⇒ Y ende-h ∈ R B × N × T × H
其餘計算過程和編碼器自注意力機制類似。