einsum初探

Einsum 是幹嘛的?

使用愛因斯坦求和約定,可以以簡單的方式表示許多常見的多維線性代數數組運算。舉個栗子:給定兩個矩陣A和B,我們想對它們做一些操作,比如 multiply、sum或者transpose。雖然numpy裏面有可以直接使用的接口,能夠實現這些功能,但是使用enisum可以做的更快、更節省空間。比如:

A = np.array([0, 1, 2])
B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

我們想計算A和B的element-wise乘積,然後按行求和。如果不使用einsum接口,需要先對A做reshape到和B一樣的形狀,創建一個臨時的數組A[:, np.newaxis],然後在做乘積並按行求和:

(A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

使用einsum接口,不需要創建reshape後的臨時數組,只是簡單地對行中的乘積求和,這樣會加速三倍:

np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

如何使用 einsum

使用einsum的關鍵是,正確地labelling(標記)輸入數組和輸出數組的axes(軸)。我們可以使用字符串(比如:ijk,這種表示方式更常用)或者一個整數列表(比如:[0,1])來標記axes。 再來舉個栗子:爲了實現矩陣乘,我們可以這麼寫

np.einsum('ij,jk->ik', A, B)

字符串'ij,jk->ik'可以根據'->'的位置來切分,左邊的部分('ij,jk')標記了輸入的axes,右邊的('ik')標記了輸出的axes。輸入標記又根據','的位置進行切分,'ij'標記了第一個輸入A的axes,'jk'標記了第二個輸入B的axes。'ij'、'jk'的字符長度都是2,對應着A和B爲2D數組,'ik'的長度也爲2,因此輸出也是2D數組。

給定輸入:

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])
B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])
np.einsum('ij,jk->ik', A, B)可以看作是:

  • 在輸入數組的標記之間,重複字母表示沿這些軸的值將相乘,這些乘積構成輸出數組的值。比如圖中沿着j軸做乘積。
  • 從輸出標記中省略的字母表示沿該軸的值將被求和。比如圖中的輸出沒有包含j軸,因此沿着j軸求和得到了輸出數組中的每一項。
    • 如果輸出的標記是'ijk',那麼會得到一個 3x3x3 的矩陣。
      • 輸出標記是'ik'的時候,並不會創建中間的 3x3x3 的矩陣,而是直接將總和累加到2D數組中。

    • 如果輸出的標記是空,那麼輸出整個矩陣的和。
  • 我們可以按任意順序返回不求和的軸。

我們將不指定'->'和輸出標記稱爲 explicit mode。 如果不指定'->'和輸出標記,numpy會將輸入標記中只出現一次的標記按照字母表順序,作爲輸出標記(也就是 implicit mode,後面會詳細介紹)。

'ij,jk->ik' 等價於 'ij,jk'

在explicit mode中,我們可以指定輸出標記的順序,比如:'ij,jk->ki'表示對矩陣乘做轉置。

Einsum 中的常用

對應的 einsum調用方式:

  • 向量操作:A、B均爲向量

  • 向量操作:A、B均爲2D矩陣

注意

  • einsum求和時不提升數據類型,如果使用的數據類型範圍有限,可能會得到意外的錯誤:
a = np.ones(300, dtype=np.int8)
print(np.sum(a)) # correct result
print(np.einsum('i->', a)) # produces incorrect result
300
44
  • einsum 在implicit mode可能不會按預期的順序排列軸
M = np.arange(24).reshape(2,3,4)
print(np.einsum('kij', M).shape) # 不是預期
print(np.einsum('ijk->kij', M).shape) #符合預期
(3, 4, 2)
(4, 2, 3)
np.einsum('kij', M) 實際上等價於 np.einsum('kij->ijk', M),因爲 implicit mode 下,einsum會認爲根據輸入標記,按照字母表順序排序,作爲輸出標記。
  • 最後,einsum 也不總是numpy中的最快的選擇。
dot和inner函數之類的功能通常會鏈接到BLAS庫方法,性能可能勝過einsum。 還有tensordot函數。 在多個輸入數組上進行操作時,einsum似乎很慢。

看到這裏就基本滿足常用的要求啦,如果想深入瞭解大把細節,可以越過華麗麗的分割線,勇往直前!

------------------------------ 華麗麗的分割線 ------------------------------

numpy 中的 einsum

import numpy as np
# 給定兩個向量,下面要用到
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
# 給定兩個矩陣,下面要用到
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

接口定義:

numpy.einsum(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', optimize=False)

參數:

xxx表示可選參數中的默認值)

  • subscripts : str,指定求和的操作,可以是多個,用逗號分開。除非包含顯式指示符“->”以及精確輸出形式的下標標籤,否則將執行隱式(經典的愛因斯坦求和)計算。
  • operands : list of array_like,輸入。
  • out : ndarray, optional,指定輸出。
  • dtype : {data-type, None}, optional,可以指定運算的數據類型,可能需要用戶提供類型轉換接口。默認爲None。
  • order : {‘C’, ‘F’, ‘A’, ‘K’}, optional,控制輸出的內存佈局。'C'=contiguous;'F'=Fortran contiguous;'A'表示輸入爲'F'時輸出爲'F',否則輸出爲'C';'K'表示輸出的layout應該儘可能和輸入一致。
  • casting : {‘no’, ‘equiv’, ‘safe’, ‘same_kind’, ‘unsafe’}, optional,指定可能發生的數據類型轉換,不推薦使用'unsafe'。
    • ‘no’ 表示不做數據類型轉換。
    • ‘equiv’ 表示僅允許字節順序更改。
    • ‘safe’ 表示只允許保留值的強制類型轉換。
    • ‘same_kind’ 表示僅允許安全類型轉換或同一類型(例如float64到float32)內的類型轉換。
    • ‘unsafe’ 表示可以進行任何數據轉換。
  • optimize : {False, True, ‘greedy’, ‘optimal’}, optional,控制是否進行中間優化。默認False,不做優化;設爲True則使用greedy算法。還接受np.einsum_path函數的提供的列表。

Returns:

  • output : ndarray

在implicit模式下,選擇的下標很重要,因爲輸出按字母順序重新排序。例如,在二維矩陣中:

np.einsum('ij,jh', A, B) # 返回的是矩陣乘的轉置,因爲'h'本來應該是在'i'的後面,但是這裏反序了
array([[19, 43],
       [22, 50]])

相比,在explicit模式下

np.einsum('ij,jh->ih', A, B) # 指定了輸出下標標籤的順序,因此效果等價於矩陣乘法 np.matmul(A,B)
                            # 相比之下,implicit 模式的 np.einsum('ij,jh', A, B) 效果等價於矩陣乘的轉置
array([[19, 43],
       [22, 50]])

einsum 默認不會 broadcast,需要指定省略號(...)來啓用。默認的NumPy樣式廣播是通過在每項的左側添加省略號。

另一種使用 enisum 的方式是:einsum(op0, sublist0, op1, sublist1, ..., [sublistout])。如果沒有指定輸出格式,將以 implicit 模式計算,否則將在 explicit 模式執行。

np.einsum(A, [0,0])
5

optimize參數優化 einsum 表達式的收縮順序,對於具有三個或更多操作數的收縮,這可以大大增加計算效率,但需要在計算過程中增加內存佔用量。

來一些對比:

  1. 求矩陣的跡:
print(np.einsum('ii', A)) # implicit mode
print(np.einsum(A, [0,0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.trace(A))
5
5
5
  1. 矩陣的對角元素
print(np.einsum('ii->i', A)) # explicit mode
print(np.einsum(A, [0,0], [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.diag(A))
[1 4]
[1 4]
[1 4]
  1. 對指定維度求和
print(np.einsum('ij->i', A)) # explicit mode
print(np.einsum(A, [0,1], [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.sum(A, axis=1))
[3 7]
[3 7]
[3 7]

對於高維的數組,可以使用省略號來對指定軸求和

print(np.einsum('...j->...', A)) # explicit mode
print(np.einsum(A, [Ellipsis,1], [Ellipsis])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
[3 7]
[3 7]
  1. 計算矩陣轉置,或根據指定的 axes 調整矩陣
print(np.einsum('ji', A)) # implicit mode
print(np.einsum('ij->ji', A))# explicit mode
print(np.einsum(A, [1,0])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.einsum(A, [1,0], [0,1])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.transpose(A, axes=[1,0]))
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
  1. 向量內積
print(np.einsum('i,i', a, a)) # implicit mode
print(np.einsum(a, [0], b, [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.inner(a, a))
14
32
14
  1. 矩陣向量乘積
k = np.array([1,2])
print(np.einsum('ij,j', A, k)) # implicit mode
print(np.einsum(A, [0,1], k, [1]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.einsum('ij,j->i', A, k)) # explicit mode
print(np.einsum('...j,j->...', A, k)) # explicit mode with broadcast
print(np.dot(A, k))
[ 5 11]
[ 5 11]
[ 5 11]
[ 5 11]
[ 5 11]
  1. Broadcasting
print(np.einsum('...,...', 3, A)) # implicit mode
print(np.einsum(',ij', 3, A))# explicit mode
print(np.einsum(3, [Ellipsis], A, [Ellipsis]))
print(np.multiply(3, A))
[[ 3  6]
 [ 9 12]]
[[ 3  6]
 [ 9 12]]
[[ 3  6]
 [ 9 12]]
[[ 3  6]
 [ 9 12]]
  1. Tensor Contraction
m = np.arange(60.).reshape(3,4,5)
n = np.arange(24.).reshape(4,3,2)
print(np.einsum('ijk,jil->kl', m, n)) # explicit mode
print(np.einsum(a, [0,1,2], b, [1,0,3], [2,3])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.tensordot(a,b, axes=([1,0],[0,1]))) #
[[4400. 4730.]
 [4532. 4874.]
 [4664. 5018.]
 [4796. 5162.]
 [4928. 5306.]]
[[4400. 4730.]
 [4532. 4874.]
 [4664. 5018.]
 [4796. 5162.]
 [4928. 5306.]]
[[4400. 4730.]
 [4532. 4874.]
 [4664. 5018.]
 [4796. 5162.]
 [4928. 5306.]]
  1. 鏈式數組操作: For more complicated contractions, speed ups might be achieved by repeatedly computing a ‘greedy’ path or pre-computing the ‘optimal’ path and repeatedly applying it, using an einsum_path insertion. Performance improvements can be particularly significant with larger arrays. 待補充
     

TODO: Tensorflow 中的 einsum; Pytorch 中的 einsum

參考:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章