torch.stack()的官方解釋,詳解以及例子

可以直接看最下面的例子,再回頭看前面的解釋

pytorch中,常見的拼接函數主要是兩個,分別是:

  1. stack()
  2. cat()

實際使用中,這兩個函數互相輔助:關於cat()參考torch.cat(),但是本文主要說stack()

函數的意義:使用stack是爲了保留兩個信息:[1. 序列(先後)] 和 [2. 張量矩陣] 信息,而存在的擴張拼接函數。 常出現在自然語言處理(NLP)和圖像卷積神經網絡(CV)中。

1. stack()

官方解釋:沿着一個新維度對輸入張量序列進行連接。 序列中所有的張量都應該爲相同形狀。

淺顯說法:把多個2維的張量湊成一個3維的張量;多個3維的湊成一個4維的張量…以此類推,也就是在增加新的維度進行堆疊

outputs = torch.stack(inputs, dim=0) → Tensor

參數

  • inputs : 待連接的張量序列。
    注:python的序列數據只有listtuple

  • dim : 新的維度, 必須在0len(outputs)之間。
    注:len(outputs)是生成數據的維度大小,也就是outputs的維度值。

2. 重點

  1. 函數中的輸入inputs只允許是序列;且序列內部的張量元素,必須shape相等

----舉例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必須tensor_1.shape == tensor_2.shape

  1. dim是選擇生成的維度,必須滿足0<=dim<len(outputs)len(outputs)是輸出後的tensor的維度大小

不懂的看例子,再回過頭看就懂了。

3. 例子

按下面的三步:準備數據,合成inputs,查看結果。

1.準備數據,每個的shape都是[2,3]

# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])
# x3
x3 = torch.tensor([[13,23,33],[23,33,43]],dtype=torch.int)
x3.shape # torch.Size([2,3])
# x4
x4 = torch.tensor([[14,24,34],[24,34,44]],dtype=torch.int)
x4.shape # torch.Size([2,3])

2.合成inputs

'inputs爲4個形狀爲[2 , 3]的矩陣 '
inputs = [x1, x2, x3, x4]
print(inputs)
# 打印看看結構。是4個張量
[tensor([[11, 21, 31],
         [21, 31, 41]], dtype=torch.int32),
 tensor([[12, 22, 32],
         [22, 32, 42]], dtype=torch.int32),
 tensor([[13, 23, 33],
         [23, 33, 43]], dtype=torch.int32),
 tensor([[14, 24, 34],
         [24, 34, 44]], dtype=torch.int32)]

3.查看結果, 測試不同的dim拼接結果

'選擇的 0<=dim<len()前三個 增加了新維度'
In    [1]: torch.stack(inputs, dim=0).shape
Out[1]: torch.Size([4, 2, 3]) 

In    [2]: torch.stack(inputs, dim=1).shape
Out[2]: torch.Size([2, 4, 3])

In    [3]: torch.stack(inputs, dim=2).shape
Out[3]: torch.Size([2, 3, 4])
'選擇的dim>len(outputs),所以報錯'
In    [4]: torch.stack(inputs, dim=3).shape
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

大家可以複製代碼運行一下就會發現:這個拼接後的維度大小4根據不同的dim一直變化。

dim shape
0 [4, 2, 3]
1 [2, 4, 3]
2 [2, 3, 4]
3 溢出報錯

4. 總結

  1. 函數作用:
    函數stack()序列數據內部的張量進行擴維拼接,指定維度由我們選擇、大小是生成後數據的維度區間。

  2. 存在意義:
    在自然語言處理和卷及神經網絡中, 通常爲了保留–[序列(先後)信息] 和 [張量的矩陣信息] 纔會使用stack

研究自然語言處理的同學一般知道,在循環神經網絡中,網絡的輸出數據通常是:包含了n個數據大小[batch_size, num_outputs]list,這個和[n, batch_size, num_outputs]是完全不一樣的!!!!不利於計算,需要使用stack進行拼接,保留–[1.時間步]和–[2.張量的矩陣乘積屬性]。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章