torch.cat()函數的官方解釋,詳解以及例子

可以直接看最下面的例子,再回頭看前面的解釋,就很明白了。

pytorch中,常見的拼接函數主要是兩個,分別是:

  1. stack()
  2. cat()

一般torch.cat()是爲了把函數torch.stack()得到tensor進行拼接而存在的。參考鏈接torch.stack(),但是本文主要說cat()

torch.cat()python中的內置函數cat(), 在使用和目的上,是沒有區別的。

1. cat()

函數目的: 在給定維度上對輸入的張量序列seq 進行連接操作。

outputs = torch.cat(inputs, dim=0) → Tensor

參數

  • inputs : 待連接的張量序列,可以是任意相同Tensor類型的python 序列
  • dim : 選擇的擴維, 必須在0len(inputs[0])之間,沿着此維連接張量序列。

2. 重點

  1. 輸入數據必須是序列,序列中數據是任意相同的shape的同類型tensor
  2. 維度不可以超過輸入數據的任一個張量的維度

3.舉例子

  1. 準備數據,每個的shape都是[2,3]
# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])
  1. 合成inputs
'inputs爲2個形狀爲[2 , 3]的矩陣 '
inputs = [x1, x2]
print(inputs)
'打印查看'
[tensor([[11, 21, 31],
         [21, 31, 41]], dtype=torch.int32),
 tensor([[12, 22, 32],
         [22, 32, 42]], dtype=torch.int32)]

3.查看結果, 測試不同的dim拼接結果

In    [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4,  3])

In    [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])

In    [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

大家可以複製代碼運行一下就會發現其中規律了。

總結

通常用來,把torch.stack得到tensor進行拼接而存在的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章