PyTorch 中的 tensordot 以及 einsum 函數介紹
前言
最近發現這兩個函數用得越來越頻繁, 比如在 DCN 網絡的實現中就用到了(詳見 Deep Cross Network (深度交叉網絡, DCN) 介紹與代碼分析), 但是過段時間又忘記這兩個函數到底實現啥功能, 趁着現在印象還比較深刻的時候記錄一下 😂😂😂.
從例子出發
拿一大串中文或英文來形容這兩個函數, 該懵逼的還是懵逼, 從例子出發可以很容易理解它們的具體功能. 例子來自 Stackoverflow: product-of-pytorch-tensors-along-arbitrary-axes.
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0], [0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
從例子中可以發現, einsum
同樣可以實現 tensordot
的功能. 但現在的問題是, c
等於
[[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
這個結果具體是怎麼得到的 ?
在 numpy.tensordot 文檔中, 對 tensordot
的功能解釋爲:
Compute tensor dot product along specified axes
上面代碼中, 在計算 c
時, 指定了 axes
:
c = np.tensordot(a, b, axes=([1,0], [0,1]))
其中 a
用來參與計算的軸爲 [1, 0]
, 由於 a.shape = (3, 4, 3)
, 那麼用來參與計算的子數組 A
大小爲 (4, 3)
;
對於 b
來說, 用來參與計算的軸爲 [0, 1]
, 由於 b.shape = (4, 3, 2)
, 那麼用來參與計算的子數組 B
大小爲 (4, 3)
;
最後進行子數組(tensor)間的 dot product, 即 sum(A * B)
, 得到一個 scalar, 注意 *
是 element-wise 的乘法, 而不是矩陣乘法. 經過 tensordot
後, a
還保留着第 3 個維度, 大小爲 a.shape[2] = 3
, 而 b
也保留着第 3 個維度, 大小爲 b.shape[2] = 2
, 此時 c
的大小爲 (a.shape[2], b.shape[2]) = (3, 2)
.
經過以上分析, 我們現在換種思路來計算 c
, 代碼如下:
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
aa = a.transpose((2, 1, 0)) ## aa.shape = (3, 4, 3)
bb = b.transpose((2, 0, 1)) ## bb.shape = (2, 4, 3)
print(np.sum(aa[0], bb[0]))
# 2640.0
cc = [
[np.sum(aa[0] * bb[0]), np.sum(aa[0] * bb[1])],
[np.sum(aa[1] * bb[0]), np.sum(aa[1] * bb[1])],
[np.sum(aa[2] * bb[0]), np.sum(aa[2] * bb[1])]
]
print(cc)
[[2640.0, 2838.0], [2772.0, 2982.0], [2904.0, 3126.0]]
因此, tensordot 的作用是將 axes
指定的子數組進行點乘, axes
指定具體的維度.
einsum 的用法非常豐富, 下面參考資料中的例子無不顯示着這個函數的強大:
經過上面的分析, 可以發現 enisum
可以完成 tensordot
的功能, 即:
c = torch.einsum("ijk,jil->kl", (a, b))
用指定的字符串 "ijk,jil->kl"
就能形象地說明運算的目的.
靈魂畫手
再說明一下 transpose. 它和 reshape
不一樣, 我覺得它是改變觀看 tensor 的視角. 比如對於如下矩陣:
(需要一點空間想象 😂😂😂)
參考資料
- Stackoverflow: understanding-pytorch-einsum 例子相當豐富
- einsum滿足你一切需要:深度學習中的愛因斯坦求和約定 我只是想學習下中文表達~
- numpy.tensordot Numpy 的 tensordot 文檔
- Stackoverflow: understanding-tensordot 解釋的很通俗
- Stackoverflow: product-of-pytorch-tensors-along-arbitrary-axes 本文第一個例子的出處