einsum方法

einsum求矩陣運算

einsum是什麼

使用chatGPT解讀官方的文檔

Einsum允許使用基於Einstein求和約定的簡寫格式來計算許多常見的多維線性代數數組操作,這些操作可以表示爲一個方程式。
具體格式的細節在下面進行描述,但是一般的思想是使用一些下標爲輸入操作數的每個維度進行標記,並且定義哪些下標是輸出的一部分。
然後,通過沿着那些下標不是輸出的維度對操作數的元素積進行求和來計算輸出結果。
例如,可以使用einsum來計算矩陣乘法,如torch.einsum("ij,jk->ik", A, B)。
在這種情況下,j是求和下標,i和k是輸出下標(有關更多詳細信息,請參見下面的部分)。

因此描述的也比較清楚,總結而言就是:對每個維度進行標記,輸出需要維度的結果,未出現的維度就會作爲求和的維度被消除掉。

實現

下面引用了一下官網的例子,後面在多看看,多理解一下,還有點懵。

# trace
torch.einsum('ii', torch.randn(4, 4))

# diagonal
torch.einsum('ii->i', torch.randn(4, 4))

# outer product
x = torch.randn(5)
y = torch.randn(4)
torch.einsum('i,j->ij', x, y)

# batch matrix multiplication
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
torch.einsum('bij,bjk->bik', As, Bs)



# with sublist format and ellipsis
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])



# batch permute
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape

# equivalent to torch.nn.functional.bilinear
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章