可以直接看最下面的例子,再回頭看前面的解釋,就很明白了。
在pytorch
中,常見的拼接函數主要是兩個,分別是:
stack()
cat()
一般torch.cat()
是爲了把函數torch.stack()
得到tensor
進行拼接而存在的。參考鏈接torch.stack(),但是本文主要說cat()
。
torch.cat()
和python
中的內置函數cat()
, 在使用和目的上,是沒有區別的。
1. cat()
函數目的: 在給定維度上對輸入的張量序列seq 進行連接操作。
outputs = torch.cat(inputs, dim=0) → Tensor
參數
- inputs : 待連接的張量序列,可以是任意相同
Tensor
類型的python 序列 - dim : 選擇的擴維, 必須在
0
到len(inputs[0])
之間,沿着此維連接張量序列。
2. 重點
- 輸入數據必須是序列,序列中數據是任意相同的
shape
的同類型tensor
- 維度不可以超過輸入數據的任一個張量的維度
3.舉例子
- 準備數據,每個的
shape
都是[2,3]
# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape # torch.Size([2, 3])
- 合成
inputs
'inputs爲2個形狀爲[2 , 3]的矩陣 '
inputs = [x1, x2]
print(inputs)
'打印查看'
[tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32),
tensor([[12, 22, 32],
[22, 32, 42]], dtype=torch.int32)]
3.查看結果, 測試不同的dim
拼接結果
In [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4, 3])
In [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])
In [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
大家可以複製代碼運行一下就會發現其中規律了。
總結
通常用來,把torch.stack
得到tensor
進行拼接而存在的。