PyTorch之愛因斯坦求和約定
網上關於這個函數:
torch.einsum
的介紹已經很多了,這裏列出我重點看過的一篇文章。
這篇文章寫的非常棒,很詳細。
這裏寫個簡單的例子,對於論文A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
中的下面的式子,可以很方便的藉助該函數搞定。
先寫一個一般的思路。首先要注意PyTorch中的維度順序爲N, C, H, W
學習PyTorch的關鍵是要記住這個順序。對於原文來說,下表s、t、i、j分別表示長、寬、通道、通道,所以對於這裏提到的張量和各自實際上是對應於形狀爲[b, m, h, w]
和[b, n, h, w]
的。而這裏的累加符號是對於和進行的計算,所以實際上可以轉化爲矩陣乘法。[b, m, hxw] * [b, hxw, n] = [b, m, n]
先準備數據:
import torch
a = torch.rand(2, 3, 4, 5)
b = torch.rand_like(a)
一般的利用矩陣乘法的思路:
c = torch.bmm(a.view(2, 3, -1), b.view(2, 3, -1).transpose(1, 2)) / (4 * 5)
而當使用torch.einsum
的時候只需要一行:
d = torch.einsum("bist,bjst->bij", [a, b]) / (4 * 5)
驗證結果:
d == c
# output:
tensor([[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]]], dtype=torch.uint8)
關於einsum
維度記憶的小技巧
以前面的torch.einsum("bist,bjst->bij", [a, b])
爲例。這裏einsum
的第一個參數表示的維度變換關係,也就是各個維度的自己的索引下標。一般只要滿足對應關係即可。即各個維度使用不同的下標,如果一樣,那就會一起進行加和計算(可以認爲是對於索引遍歷的過程彙總二者是同步的)。對於逗號的分隔表示對應於後面[]
(也可以不用[]
包裹,因爲第二部分參數使用的是一個可變長參數接收的)中的不同張量,也就是這裏的bist
對應於a
,而bjst
對應於b
。下標按照張量的對應維度調整好之後,就可以開始計算了。由於這裏計算的是針對s&t
的累加,最後s&t
都消除了,僅剩下來d
的b&i&j
三個索引。所以也就順其自然的寫出了這樣的變換關係:bist,bjst->bij
。相當的方便!
這裏實際上就是前面參考文章中的“2.11 張量縮約”一個例子。