PyTorch 中的 tensordot 以及 einsum 函數介紹

PyTorch 中的 tensordot 以及 einsum 函數介紹

前言

最近發現這兩個函數用得越來越頻繁, 比如在 DCN 網絡的實現中就用到了(詳見 Deep Cross Network (深度交叉網絡, DCN) 介紹與代碼分析), 但是過段時間又忘記這兩個函數到底實現啥功能, 趁着現在印象還比較深刻的時候記錄一下 😂😂😂.

從例子出發

拿一大串中文或英文來形容這兩個函數, 該懵逼的還是懵逼, 從例子出發可以很容易理解它們的具體功能. 例子來自 Stackoverflow: product-of-pytorch-tensors-along-arbitrary-axes.

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0], [0,1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)

從例子中可以發現, einsum 同樣可以實現 tensordot 的功能. 但現在的問題是, c 等於

[[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

這個結果具體是怎麼得到的 ?

numpy.tensordot 文檔中, 對 tensordot 的功能解釋爲:

Compute tensor dot product along specified axes

上面代碼中, 在計算 c 時, 指定了 axes:

c = np.tensordot(a, b, axes=([1,0], [0,1]))

其中 a 用來參與計算的軸爲 [1, 0], 由於 a.shape = (3, 4, 3), 那麼用來參與計算的子數組 A 大小爲 (4, 3);
對於 b 來說, 用來參與計算的軸爲 [0, 1], 由於 b.shape = (4, 3, 2), 那麼用來參與計算的子數組 B 大小爲 (4, 3);
最後進行子數組(tensor)間的 dot product, 即 sum(A * B), 得到一個 scalar, 注意 * 是 element-wise 的乘法, 而不是矩陣乘法. 經過 tensordot 後, a 還保留着第 3 個維度, 大小爲 a.shape[2] = 3, 而 b 也保留着第 3 個維度, 大小爲 b.shape[2] = 2, 此時 c 的大小爲 (a.shape[2], b.shape[2]) = (3, 2).

經過以上分析, 我們現在換種思路來計算 c, 代碼如下:

import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
aa = a.transpose((2, 1, 0))  ## aa.shape = (3, 4, 3)
bb = b.transpose((2, 0, 1))  ## bb.shape = (2, 4, 3)
print(np.sum(aa[0], bb[0]))
# 2640.0
cc = [
	[np.sum(aa[0] * bb[0]), np.sum(aa[0] * bb[1])],
	[np.sum(aa[1] * bb[0]), np.sum(aa[1] * bb[1])],
	[np.sum(aa[2] * bb[0]), np.sum(aa[2] * bb[1])]
	]
print(cc)
[[2640.0, 2838.0], [2772.0, 2982.0], [2904.0, 3126.0]]

因此, tensordot 的作用是將 axes 指定的子數組進行點乘, axes 指定具體的維度.

einsum 的用法非常豐富, 下面參考資料中的例子無不顯示着這個函數的強大:

經過上面的分析, 可以發現 enisum 可以完成 tensordot 的功能, 即:

c = torch.einsum("ijk,jil->kl", (a, b))

用指定的字符串 "ijk,jil->kl" 就能形象地說明運算的目的.

靈魂畫手

再說明一下 transpose. 它和 reshape 不一樣, 我覺得它是改變觀看 tensor 的視角. 比如對於如下矩陣:
(需要一點空間想象 😂😂😂)
在這裏插入圖片描述

參考資料

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章