Einsum 是幹嘛的?
使用愛因斯坦求和約定,可以以簡單的方式表示許多常見的多維線性代數數組運算。舉個栗子:給定兩個矩陣A和B,我們想對它們做一些操作,比如 multiply、sum或者transpose。雖然numpy裏面有可以直接使用的接口,能夠實現這些功能,但是使用enisum可以做的更快、更節省空間。比如:
A = np.array([0, 1, 2])
B = np.array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
我們想計算A和B的element-wise乘積,然後按行求和。如果不使用einsum接口,需要先對A做reshape到和B一樣的形狀,創建一個臨時的數組A[:, np.newaxis],然後在做乘積並按行求和:
(A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])
使用einsum接口,不需要創建reshape後的臨時數組,只是簡單地對行中的乘積求和,這樣會加速三倍:
np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])
如何使用 einsum
使用einsum的關鍵是,正確地labelling(標記)輸入數組和輸出數組的axes(軸)。我們可以使用字符串(比如:ijk,這種表示方式更常用)或者一個整數列表(比如:[0,1])來標記axes。 再來舉個栗子:爲了實現矩陣乘,我們可以這麼寫
np.einsum('ij,jk->ik', A, B)
字符串'ij,jk->ik'可以根據'->'的位置來切分,左邊的部分('ij,jk')標記了輸入的axes,右邊的('ik')標記了輸出的axes。輸入標記又根據','的位置進行切分,'ij'標記了第一個輸入A的axes,'jk'標記了第二個輸入B的axes。'ij'、'jk'的字符長度都是2,對應着A和B爲2D數組,'ik'的長度也爲2,因此輸出也是2D數組。
給定輸入:
A = np.array([[1, 1, 1],
[2, 2, 2],
[5, 5, 5]])
B = np.array([[0, 1, 0],
[1, 1, 0],
[1, 1, 1]])
np.einsum('ij,jk->ik', A, B)可以看作是:
- 在輸入數組的標記之間,重複字母表示沿這些軸的值將相乘,這些乘積構成輸出數組的值。比如圖中沿着j軸做乘積。
- 從輸出標記中省略的字母表示沿該軸的值將被求和。比如圖中的輸出沒有包含j軸,因此沿着j軸求和得到了輸出數組中的每一項。
- 如果輸出的標記是'ijk',那麼會得到一個 3x3x3 的矩陣。
- 輸出標記是'ik'的時候,並不會創建中間的 3x3x3 的矩陣,而是直接將總和累加到2D數組中。
- 如果輸出的標記是'ijk',那麼會得到一個 3x3x3 的矩陣。
-
- 如果輸出的標記是空,那麼輸出整個矩陣的和。
- 我們可以按任意順序返回不求和的軸。
我們將不指定'->'和輸出標記稱爲 explicit mode。 如果不指定'->'和輸出標記,numpy會將輸入標記中只出現一次的標記按照字母表順序,作爲輸出標記(也就是 implicit mode,後面會詳細介紹)。
'ij,jk->ik' 等價於 'ij,jk'
在explicit mode中,我們可以指定輸出標記的順序,比如:'ij,jk->ki'表示對矩陣乘做轉置。
Einsum 中的常用
對應的 einsum調用方式:
- 向量操作:A、B均爲向量
- 向量操作:A、B均爲2D矩陣
注意
- einsum求和時不提升數據類型,如果使用的數據類型範圍有限,可能會得到意外的錯誤:
a = np.ones(300, dtype=np.int8)
print(np.sum(a)) # correct result
print(np.einsum('i->', a)) # produces incorrect result
300
44
- einsum 在implicit mode可能不會按預期的順序排列軸
M = np.arange(24).reshape(2,3,4)
print(np.einsum('kij', M).shape) # 不是預期
print(np.einsum('ijk->kij', M).shape) #符合預期
(3, 4, 2)
(4, 2, 3)
np.einsum('kij', M) 實際上等價於 np.einsum('kij->ijk', M),因爲 implicit mode 下,einsum會認爲根據輸入標記,按照字母表順序排序,作爲輸出標記。
- 最後,einsum 也不總是numpy中的最快的選擇。
dot和inner函數之類的功能通常會鏈接到BLAS庫方法,性能可能勝過einsum。 還有tensordot函數。 在多個輸入數組上進行操作時,einsum似乎很慢。
看到這裏就基本滿足常用的要求啦,如果想深入瞭解大把細節,可以越過華麗麗的分割線,勇往直前!
------------------------------ 華麗麗的分割線 ------------------------------
numpy 中的 einsum
import numpy as np
# 給定兩個向量,下面要用到
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
# 給定兩個矩陣,下面要用到
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
接口定義:
numpy.einsum(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', optimize=False)
參數:
(xxx表示可選參數中的默認值)
- subscripts : str,指定求和的操作,可以是多個,用逗號分開。除非包含顯式指示符“->”以及精確輸出形式的下標標籤,否則將執行隱式(經典的愛因斯坦求和)計算。
- operands : list of array_like,輸入。
- out : ndarray, optional,指定輸出。
- dtype : {data-type, None}, optional,可以指定運算的數據類型,可能需要用戶提供類型轉換接口。默認爲None。
- order : {‘C’, ‘F’, ‘A’, ‘K’}, optional,控制輸出的內存佈局。'C'=contiguous;'F'=Fortran contiguous;'A'表示輸入爲'F'時輸出爲'F',否則輸出爲'C';'K'表示輸出的layout應該儘可能和輸入一致。
- casting : {‘no’, ‘equiv’, ‘safe’, ‘same_kind’, ‘unsafe’}, optional,指定可能發生的數據類型轉換,不推薦使用'unsafe'。
- ‘no’ 表示不做數據類型轉換。
- ‘equiv’ 表示僅允許字節順序更改。
- ‘safe’ 表示只允許保留值的強制類型轉換。
- ‘same_kind’ 表示僅允許安全類型轉換或同一類型(例如float64到float32)內的類型轉換。
- ‘unsafe’ 表示可以進行任何數據轉換。
- optimize : {False, True, ‘greedy’, ‘optimal’}, optional,控制是否進行中間優化。默認False,不做優化;設爲True則使用greedy算法。還接受np.einsum_path函數的提供的列表。
Returns:
- output : ndarray
在implicit模式下,選擇的下標很重要,因爲輸出按字母順序重新排序。例如,在二維矩陣中:
np.einsum('ij,jh', A, B) # 返回的是矩陣乘的轉置,因爲'h'本來應該是在'i'的後面,但是這裏反序了
array([[19, 43],
[22, 50]])
相比,在explicit模式下
np.einsum('ij,jh->ih', A, B) # 指定了輸出下標標籤的順序,因此效果等價於矩陣乘法 np.matmul(A,B)
# 相比之下,implicit 模式的 np.einsum('ij,jh', A, B) 效果等價於矩陣乘的轉置
array([[19, 43],
[22, 50]])
einsum 默認不會 broadcast,需要指定省略號(...)來啓用。默認的NumPy樣式廣播是通過在每項的左側添加省略號。
另一種使用 enisum 的方式是:einsum(op0, sublist0, op1, sublist1, ..., [sublistout])。如果沒有指定輸出格式,將以 implicit 模式計算,否則將在 explicit 模式執行。
np.einsum(A, [0,0])
5
optimize參數優化 einsum 表達式的收縮順序,對於具有三個或更多操作數的收縮,這可以大大增加計算效率,但需要在計算過程中增加內存佔用量。
來一些對比:
- 求矩陣的跡:
print(np.einsum('ii', A)) # implicit mode
print(np.einsum(A, [0,0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.trace(A))
5
5
5
- 矩陣的對角元素
print(np.einsum('ii->i', A)) # explicit mode
print(np.einsum(A, [0,0], [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.diag(A))
[1 4]
[1 4]
[1 4]
- 對指定維度求和
print(np.einsum('ij->i', A)) # explicit mode
print(np.einsum(A, [0,1], [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.sum(A, axis=1))
[3 7]
[3 7]
[3 7]
對於高維的數組,可以使用省略號來對指定軸求和
print(np.einsum('...j->...', A)) # explicit mode
print(np.einsum(A, [Ellipsis,1], [Ellipsis])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
[3 7]
[3 7]
- 計算矩陣轉置,或根據指定的 axes 調整矩陣
print(np.einsum('ji', A)) # implicit mode
print(np.einsum('ij->ji', A))# explicit mode
print(np.einsum(A, [1,0])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.einsum(A, [1,0], [0,1])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.transpose(A, axes=[1,0]))
[[1 3]
[2 4]]
[[1 3]
[2 4]]
[[1 3]
[2 4]]
[[1 3]
[2 4]]
[[1 3]
[2 4]]
- 向量內積
print(np.einsum('i,i', a, a)) # implicit mode
print(np.einsum(a, [0], b, [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.inner(a, a))
14
32
14
- 矩陣向量乘積
k = np.array([1,2])
print(np.einsum('ij,j', A, k)) # implicit mode
print(np.einsum(A, [0,1], k, [1]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.einsum('ij,j->i', A, k)) # explicit mode
print(np.einsum('...j,j->...', A, k)) # explicit mode with broadcast
print(np.dot(A, k))
[ 5 11]
[ 5 11]
[ 5 11]
[ 5 11]
[ 5 11]
- Broadcasting
print(np.einsum('...,...', 3, A)) # implicit mode
print(np.einsum(',ij', 3, A))# explicit mode
print(np.einsum(3, [Ellipsis], A, [Ellipsis]))
print(np.multiply(3, A))
[[ 3 6]
[ 9 12]]
[[ 3 6]
[ 9 12]]
[[ 3 6]
[ 9 12]]
[[ 3 6]
[ 9 12]]
- Tensor Contraction
m = np.arange(60.).reshape(3,4,5)
n = np.arange(24.).reshape(4,3,2)
print(np.einsum('ijk,jil->kl', m, n)) # explicit mode
print(np.einsum(a, [0,1,2], b, [1,0,3], [2,3])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.tensordot(a,b, axes=([1,0],[0,1]))) #
[[4400. 4730.]
[4532. 4874.]
[4664. 5018.]
[4796. 5162.]
[4928. 5306.]]
[[4400. 4730.]
[4532. 4874.]
[4664. 5018.]
[4796. 5162.]
[4928. 5306.]]
[[4400. 4730.]
[4532. 4874.]
[4664. 5018.]
[4796. 5162.]
[4928. 5306.]]
- 鏈式數組操作: For more complicated contractions, speed ups might be achieved by repeatedly computing a ‘greedy’ path or pre-computing the ‘optimal’ path and repeatedly applying it, using an einsum_path insertion. Performance improvements can be particularly significant with larger arrays. 待補充
TODO: Tensorflow 中的 einsum; Pytorch 中的 einsum
參考: