various product
- dot product
- element-wise product
- inner product
- outer product
- Kronecker product(克羅內克積)
torch.enisum
Einstein notation
1 dot product
線代中的矩陣乘法,A.dot(B)
,表示A的行乘以B的列,A的列數需要等於B的行數。
A=⎝⎜⎜⎜⎛a11a21⋮am1a12a22⋮am2⋯⋯⋱⋯a1na2n⋮amn⎠⎟⎟⎟⎞,B=⎝⎜⎜⎜⎛b11b21⋮bn1b12b22⋮bn2⋯⋯⋱⋯b1pb2p⋮bnp⎠⎟⎟⎟⎞
C=⎝⎜⎜⎜⎛c11c21⋮cm1c12c22⋮cm2⋯⋯⋱⋯c1pc2p⋮cmp⎠⎟⎟⎟⎞
cij=ai1b1j+ai2b2j+⋯+ainbnj=k=1∑naikbkj,
2 element-wise product
矩陣的形狀需要一致,表示前後對應座標的元素相乘得到的矩陣
(A∘B)ij=(A⊙B)ij=(A)ij(B)ij.
⎣⎡a11a21a31a12a22a32a13a23a33⎦⎤∘⎣⎡b11b21b31b12b22b32b13b23b33⎦⎤=⎣⎡a11b11a21b21a31b31a12b12a22b22a32b32a13b13a23b23a33b33⎦⎤.
3 inner product
向量的內積,即點乘,結果是一個標量(scalar)。
⟨u,v⟩=uTv
4 outer product
向量的外積,即叉乘,結果是一個向量(vector)。
(u⊗v)ij=uivj
u⊗v=uvT=⎣⎢⎢⎡u1u2u3u4⎦⎥⎥⎤[v1v2v3]=⎣⎢⎢⎡u1v1u2v1u3v1u4v1u1v2u2v2u3v2u4v2u1v3u2v3u3v3u4v3⎦⎥⎥⎤.
5 Kronecker product(克羅內克積)
對應維度兩兩相乘,如下例子,a11分別與b11、b21相乘爲縱向前兩個元素,分別與b11、b12、b13相乘爲橫向前三個元素
6 愛因斯坦求和
用字母標記待操作的向量或矩陣的維度,並用相應的字母標記需要的結果,中間用箭頭連接。愛因斯坦求和相當於是一個矩陣處理的接口,可以計算矩陣的跡、對角矩陣、求和、點乘等等。“輸入標記中重複字母表示沿這些軸的值將相乘”,“輸出標記中省略字母表示沿該軸的值將被求和”
np.einsum('ij,jk->ik', A, B)
可以看作是:2D矩陣乘法:
np.einsum('ij,jk->ijk', A, B)
,注意此時的輸出是三維的ijk
,根據下圖推測出計算方式是:用第一個矩陣i
維度上的三個向量,即j
向量,分別乘第二個矩陣沿着j
方向的整個矩陣。
AB均是向量
AB均是矩陣
我自己的理解
A = np.arange(18).reshape(3,2,3)
B = np.arange(24).reshape(2,3,4)
x = np.einsum('ijk,jil->kl',A,B)
print(A,'\n',B,'\n',x)
"""
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]
[[12 13 14]
[15 16 17]]]
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
[[600 645 690 735]
[660 711 762 813]
[720 777 834 891]]
"""
還不明白的話,看我下面的圖幫助理解,我認爲矩陣問題,特別是高維tensor問題,一定要形象的畫出來才方便理解。
下面是一個代碼中的具體例子:
torch.manual_seed(1)
x = torch.randn(1, 2, 3,4)
A = torch.randn(3, 2)
print(x,'\n',A.T)
"""
tensor([[[[-1.5256, -0.7502, -0.6540, -1.6095],
[-0.1002, -0.6092, -0.9798, -1.6091],
[ 0.4391, 1.1712, 1.7674, -0.0954]],
[[ 0.1394, -1.5785, -0.3206, -0.2993],
[-0.7984, 0.3357, 0.2753, 1.7163],
[-0.0561, 0.9107, -1.3924, 2.6891]]]])
tensor([[ 3.5870, 1.5987, 0.3255],
[-1.8313, -1.2770, -0.4791]])
"""
x = torch.einsum('ncvl,vw->ncwl', x, A)
>>> x
"""
tensor([[[[-5.4895, -3.2838, -3.3369, -8.3767],
[ 2.7113, 1.5907, 1.6020, 5.0480]],
[[-0.7948, -4.8289, -1.1631, 2.5454],
[ 0.7913, 2.0256, 0.9027, -2.9320]]]])
"""
Reference:
- 小夕的知乎
- wikipedia
kipedia.org/wiki/Einstein_notation#Common_operations_in_this_notation)
- Baidu baike