一個關於pytorch的tensor點乘的小問題

事情的緣由是,坐旁邊的學姐有一段代碼將兩個二維數組相乘。特別的是,既不是點乘,也不是矩陣乘法,而是將各自每一行分別相乘再拼接得到一個三維數組,具體代碼大致如下

import torch
a = torch.Tensor(range(6)).reshape(2, 3)
b = torch.Tensor(range(1, 7)).reshape(2, 3)
batch = len(a)
length = len(a[0])
c = torch.zeros(batch, batch, length)
start = time.time()
for i in range(batch):
    for j in range(batch):
        c[i][j] = a[i] * b[j]

可以看到,實際上是使用a的每一行乘以b的每一行作爲c的一個元素,最終c是一個三維數組。

但問題在於,python解釋器的執行速度較慢,因此,我做了改進,將a和b分別按照維度1和維度進行擴展,再相乘,結果是一樣的。

a_batch = torch.stack([a] * batch, dim=1)
b_batch = torch.stack([b] * batch, dim=0)
c_batch = a_batch * b_batch

速度顯然快了很多,但是比較佔內存且麻煩。回頭看了一個學弟的代碼的解決方案,使用了None擴展便捷地解決了這個問題。

a_2 = a[:, None, :]
b_2 = b[None, :, :]
c_2 = a_2 * b_2

實際上,a_2和b_2的維度大小是在None那一維度爲1而不是我stack那樣的數個。輸出各自的維度如下

a.shape  torch.Size([2, 3])
a_batch.shape  torch.Size([2, 2, 3])
a_2.shape  torch.Size([2, 1, 3])
b.shape  torch.Size([2, 3])
b_batch.shape  torch.Size([2, 2, 3])
b_2.shape  torch.Size([1, 2, 3])

疑惑是,之前看的csdn博客,都說過pytorch的矩陣點乘需要兩個矩陣的維度相同,然而a_2和b_2爲何維度不同也能相乘呢?因此去查詢pytorch的官方文檔。即查看torch.mul的文檔https://pytorch.org/docs/stable/torch.html?highlight=mul#torch.mul,可見

可見,其實在維度不相同時,如果矩陣是可廣播的也可以相乘,查看broadcastable的定義 https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

中文翻譯即爲

如果以下規則成立,則兩個張量是“可廣播的”:

  1. 每個張量至少有一個維度
  2. 當迭代尺寸時,從尾部尺寸開始,尺寸必須a相等,或者b其中一個尺寸爲1,或者c尺寸不存在

官方舉例

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

所以,還是要多查看官方文檔,博客上很多是不全面的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章