PyTorch之愛因斯坦求和約定

PyTorch之愛因斯坦求和約定


網上關於這個函數:torch.einsum的介紹已經很多了,這裏列出我重點看過的一篇文章。

這篇文章寫的非常棒,很詳細。

這裏寫個簡單的例子,對於論文A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning中的下面的式子,可以很方便的藉助該函數搞定。
在這裏插入圖片描述
先寫一個一般的思路。首先要注意PyTorch中的維度順序爲N, C, H, W學習PyTorch的關鍵是要記住這個順序。對於原文來說,下表s、t、i、j分別表示長、寬、通道、通道,所以對於這裏提到的張量Fs,t,i1F^1_{s,t,i}Fs,t,j2F^2_{s,t,j}各自實際上是對應於形狀爲[b, m, h, w][b, n, h, w]的。而這裏的累加符號是對於sstt進行的計算,所以實際上可以轉化爲矩陣乘法。[b, m, hxw] * [b, hxw, n] = [b, m, n]

先準備數據:

import torch
a = torch.rand(2, 3, 4, 5)
b = torch.rand_like(a)

一般的利用矩陣乘法的思路:

c = torch.bmm(a.view(2, 3, -1), b.view(2, 3, -1).transpose(1, 2)) / (4 * 5)

而當使用torch.einsum的時候只需要一行:

d = torch.einsum("bist,bjst->bij", [a, b]) / (4 * 5)

驗證結果:

d == c

# output:
tensor([[[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]],
        [[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]]], dtype=torch.uint8)

關於einsum維度記憶的小技巧

以前面的torch.einsum("bist,bjst->bij", [a, b])爲例。這裏einsum的第一個參數表示的維度變換關係,也就是各個維度的自己的索引下標。一般只要滿足對應關係即可。即各個維度使用不同的下標,如果一樣,那就會一起進行加和計算(可以認爲是對於索引遍歷的過程彙總二者是同步的)。對於逗號的分隔表示對應於後面[](也可以不用[]包裹,因爲第二部分參數使用的是一個可變長參數接收的)中的不同張量,也就是這裏的bist對應於a,而bjst對應於b。下標按照張量的對應維度調整好之後,就可以開始計算了。由於這裏計算的是針對s&t的累加,最後s&t都消除了,僅剩下來db&i&j三個索引。所以也就順其自然的寫出了這樣的變換關係:bist,bjst->bij。相當的方便!

這裏實際上就是前面參考文章中的“2.11 張量縮約”一個例子。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章