1. 拼接
(1). cat
注意要指出在哪個維度上進行拼接:
>>> import torch
>>> a = torch.rand(4,32,8)
>>> b = torch.rand(5,32,8)
>>> torch.cat([a,b],dim=0).shape
torch.Size([9, 32, 8])
且除了要拼接的維度外,其他維度數值必須保持一致,否則會報錯:
>>> import torch
>>> a = torch.rand(4,3,32,32)
>>> b = torch.rand(4,1,32,32)
>>> torch.cat([a,b],dim=0).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1
(2). stack
會創建新的維度,所以在舊維度上必須完全一摸一樣:
>>> import torch
>>> a = torch.rand(32,8)
>>> b = torch.rand(32,8)
>>> torch.stack([a,b],dim=0).shape
torch.Size([2, 32, 8])
2. 拆分
(1). split
根據長度拆分
>>> import torch
>>> a = torch.rand(3,32,8)
>>> aa, bb = a.split([2,1],dim=0)
>>> aa.shape, bb.shape
(torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))
>>> import torch
>>> a = torch.rand(2,32,8)
>>> aa,bb = a.split(1,dim=0)
>>> aa.shape,bb.shape
(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
如果把2拆分成N塊,每塊的長度是2,則會報錯。
在理論上就是不拆分,也就是一個拆分成一塊,但在pytorch中不可以這樣做。
>>> import torch
>>> a = torch.rand(2,32,8)
>>> aa,bb = a.split(2,dim=0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: not enough values to unpack (expected 2, got 1)
(2). chunk
按數量拆分:
就比較好理解,算除法就行。
>>> import torch
>>> a = torch.rand(8,32,8)
>>> aa,bb = a.chunk(2,dim=0)
>>> aa.shape,bb.shape
(torch.Size([4, 32, 8]), torch.Size([4, 32, 8]))