tf.einsum—愛因斯坦求和約定

1. einsum記法

如果你像我一樣,發現記住PyTorch/TensorFlow中那些計算點積、外積、轉置、矩陣-向量乘法、矩陣-矩陣乘法的函數名字和簽名很費勁,那麼einsum記法就是我們的救星。einsum記法是一個表達以上這些運算,包括複雜張量運算在內的優雅方式,基本上,可以把einsum看成一種領域特定語言。一旦你理解並能利用einsum,除了不用記憶和頻繁查找特定庫函數這個好處以外,你還能夠更迅速地編寫更加緊湊、高效的代碼。而不使用einsum的時候,容易出現引入不必要的張量變形或轉置運算,以及可以省略的中間張量的現象。此外,einsum這樣的領域特定語言有時可以編譯到高性能代碼,事實上,PyTorch最近引入的能夠自動生成GPU代碼併爲特定輸入尺寸自動調整代碼的張量理解(Tensor Comprehensions)就基於類似einsum的領域特定語言。此外,可以使用opt einsum和tf einsum opt這樣的項目優化einsum表達式的構造順序。

比方說,我們想要將兩個矩陣AϵRI×KA\epsilon\mathbb{R}^{I\times K}BϵRK×JB\epsilon\mathbb{R}^{K \times J}相乘,接着計算每列的和,最終得到向量cϵRJc \epsilon\mathbb{R}^J。使用愛因斯坦求和約定,這可以表達爲:
cj=ikAikBkj=AikBkj c_j=\sum_i\sum_kA_{ik}B_{kj}=A_{ik}B_{kj}

這一表達式指明瞭cc中的每個元素cic_i是如何計算的,列向量Ai:A_{i:}乘以行向量B:jB_{:j},然後求和。注意,在愛因斯坦求和約定中,我們省略了求和符號\sum,因爲我們隱式地累加重複的下標(這裏是k)和輸出中未指明的下標(這裏是i)。

在深度學習中,我經常碰到的一個問題是,變換高階張量到向量。例如,我可能有一個張量,其中包含一個batch中的N個訓練樣本,每個樣本是一個長度爲T的K維詞向量序列,我想把詞向量投影到一個不同的維度Q。如果將這個張量記作TϵRN×T×KT\epsilon\mathbb{R}^{N\times T\times K},將投影矩陣記作WϵRK×QW\epsilon\mathbb{R}^{K\times Q},那麼所需計算可以用einsumeinsum表達爲:
Cntq=kTntkWkq=TntkWkqC_{ntq}=\sum_kT_{ntk}W_{kq}=T_{ntk}W_{kq}

最後一個例子,比方說有一個四階張量TϵRN×T×K×MT\epsilon\mathbb{R}^{N\times T\times K\times M},我們想要使用之前的投影矩陣將第三維投影至QQ維,並累加第二維,然後轉置結果中的第一維和最後一維,最終得到張量CϵRM×Q×NC\epsilon \mathbb{R}^{M\times Q\times N}。einsum可以非常簡潔地表達這一切:
Cmqn=tkTntkmWkq=TntkmWkqC_{mqn}=\sum_t\sum_kT_{ntkm}W_{kq}=T_{ntkm}W_{kq}
注意,我們通過交換下標n和m(Cmqn而不是Cnqm),轉置了張量構造結果。

2. Numpy、PyTorch、TensorFlow中的einsum

einsum在numpy中實現爲np.einsum,在PyTorch中實現爲torch.einsum,在TensorFlow中實現爲tf.einsum,均使用一致的簽名einsum(equation, operands),其中equation是表示愛因斯坦求和約定的字符串,而operands則是張量序列(在numpy和TensorFlow中是變長參數列表,而在PyTorch中是列表)。

例如,我們的第一個例子,cj=ikAikBkj=AikBkj c_j=\sum_i\sum_kA_{ik}B_{kj}=A_{ik}B_{kj} 寫成equation字符串就是ik,kj -> j。注意這裏(i, j, k)的命名是任意的,但需要一致。

PyTorch和TensorFlow像numpy支持einsum的好處之一是einsum可以用於神經網絡架構的任意計算圖,並且可以反向傳播。典型的einsum調用格式如下:

在這裏插入圖片描述
上式中◻是佔位符,表示張量維度。上面的例子中,arg1和arg3是矩陣,arg2是二階張量,這一einsum運算的結果(result)是矩陣。注意einsum處理的是可變數量的輸入。在上面的例子中,einsum指定了三個參數之上的操作,但它同樣可以用在牽涉一個參數、兩個參數、三個以上參數的操作上。學習einsum的最佳途徑是通過學習一些例子,所以下面我們將展示一下,在許多深度學習模型中常用的庫函數,用einsum該如何表達(以PyTorch爲例)。

1 矩陣轉置
Bji=AijB_{ji}=A_{ij}

import torch
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->ji', [a])
tensor([[ 0.,  3.],
        [ 1.,  4.],
        [ 2.,  5.]])

2 求和
b=ijAij=Aijb=\sum_i\sum_jA_{ij}=A_{ij}

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

3 列求和
bj=iAij=Aijb_j=\sum_iA_{ij}=A_{ij}

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3.,  5.,  7.])

4 行求和
bj=jAij=Aijb_j=\sum_jA_{ij}=A_{ij}

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->i', [a])
tensor([  3.,  12.])

5 矩陣-向量相乘
ci=kAikbk=Aikbkc_i=\sum_kA_{ik}b_k=A_{ik}b_k

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([  5.,  14.])

6 矩陣-矩陣相乘
Cij=kAikBkj=AikBkjC_{ij}=\sum_kA_{ik}B_{kj}=A_{ik}B_{kj}

a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
torch.einsum('ik,kj->ij', [a, b])
tensor([[  25.,   28.,   31.,   34.,   37.],
        [  70.,   82.,   94.,  106.,  118.]])

7 點積
c=iaibi=aibic=\sum_ia_ib_i=a_ib_i

a = torch.arange(3)
b = torch.arange(3,6)  # [3, 4, 5]
torch.einsum('i,i->', [a, b])
tensor(14.)

8 哈達瑪積
Cij=AijBijC_{ij}=A_{ij}B_{ij}

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[  0.,   7.,  16.],
        [ 27.,  40.,  55.]])

9 外積
Cij=aibjC_{ij}=a_ib_j

a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])
tensor([[  0.,   0.,   0.,   0.],
        [  3.,   4.,   5.,   6.],
        [  6.,   8.,  10.,  12.]])

10 batch矩陣相乘
Cijl=kAijkBikl=AijkBiklC_{ijl}=\sum_kA_{ijk}B_{ikl}=A_{ijk}B_{ikl}

a = torch.randn(3,2,5)

b = torch.randn(3,5,3)

torch.einsum('ijk,ikl->ijl', [a, b])

tensor([[[ 1.0886,  0.0214,  1.0690],

         [ 2.0626,  3.2655, -0.1465]],


        [[-6.9294,  0.7499,  1.2976],

         [ 4.2226, -4.5774, -4.8947]],


        [[-2.4289, -0.7804,  5.1385],

         [ 0.8003,  2.9425,  1.7338]]])

11 張量縮約
batch矩陣相乘是張量縮約的一個特例。比方說,我們有兩個張量,一個n階張量A ∈ ℝI1 × ⋯ × In,一個m階張量B ∈ ℝJ1 × ⋯ × Jm。舉例來說,我們取n = 4,m = 5,並假定I2 = J3且I3 = J5。我們可以將這兩個張量在這兩個維度上相乘(A張量的第2、3維度,B張量的3、5維度),最終得到一個新張量C ∈ ℝI1 × I4 × J1 × J2 × J4,如下所示:
Cpstuv=qrApqrsBtuqvr=ApqrsBtuqvrC_{pstuv}=\sum_q\sum_rA_{pqrs}B_{tuqvr}=A_{pqrs}B_{tuqvr}

a = torch.randn(2,3,5,7)

b = torch.randn(11,13,3,17,5)

torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape

torch.Size([2, 7, 11, 13, 17])

12 雙線性變換
Dij=klAikBjklCil=AikBjklCilD_{ij}=\sum_k\sum_lA_{ik}B_{jkl}C_{il}=A_{ik}B_{jkl}C_{il}

a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
torch.einsum('ik,jkl,il->ij', [a, b, c])
tensor([[ 3.8471,  4.7059, -3.0674, -3.2075, -5.2435],
        [-3.5961, -5.2622, -4.1195,  5.5899,  0.4632]])
3.總結

einsum是一個函數走天下,是處理各種張量操作的瑞士軍刀。話雖如此,“einsum滿足你一切需要”顯然誇大其詞了。從上面的真實用例可以看到,我們仍然需要在einsum之外應用非線性和構造額外維度(unsqueeze)。類似地,分割、連接、索引張量仍然需要應用其他庫函數。

使用einsum的麻煩之處是你需要手動實例化參數,操心它們的初始化,並在模型中註冊這些參數。不過我仍然強烈建議你在實現模型時,考慮下有哪些情況適合使用einsum.

from

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章