Merge or split
Cat
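Statistics about scores: one tensor holds the scores of classes 1-4 and another holds classes 5-9; cat merges them along the class dimension.
[class1-4,students,scores]
[class5-9,students,scores]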
import torch
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
print(torch.cat([a,b],dim=0).shape)
out: torch.Size([9, 32, 8])
a1 = torch.rand(4,3,32,32)
a2 = torch.rand(5,3,32,32)
print(torch.cat([a1,a2],dim=0).shape)
out: torch.Size([9, 3, 32, 32])
a1 = torch.rand(4,3,32,32)
a2 = torch.rand(4,1,32,32)
# print(torch.cat([a1,a2],dim=0).shape)  # raises an error
print(torch.cat([a1,a2],dim=1).shape)
out (dim=0): error: Sizes of tensors must match except in dimension 0. The error occurs because the shapes differ in dim=1; only the dimension being concatenated along may differ.
out (dim=1): torch.Size([4, 4, 32, 32])
a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
print(torch.cat([a1,a2],dim=2).shape)
out: torch.Size([4, 3, 32, 32])
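The shape rule in the cat examples above can be sketched as a small helper. `cat_shape` is a hypothetical function written for illustration, not part of PyTorch. It assumes a non-negative `dim` and ignores special cases such as empty tensors.

```python
import torch

def cat_shape(shapes, dim):
    """Predict the shape torch.cat would return, or raise ValueError.

    Hypothetical helper for illustration; assumes dim >= 0.
    """
    result = list(shapes[0])
    result[dim] = 0  # sizes along the cat dimension are summed below
    for shape in shapes:
        if len(shape) != len(result):
            raise ValueError("all tensors must have the same number of dimensions")
        for axis, size in enumerate(shape):
            if axis == dim:
                result[dim] += size
            elif size != result[axis]:
                raise ValueError(f"size mismatch in dim {axis}: {size} vs {result[axis]}")
    return tuple(result)

a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(4, 1, 32, 32)
predicted = cat_shape([a1.shape, a2.shape], dim=1)
actual = tuple(torch.cat([a1, a2], dim=1).shape)
print(predicted, actual)  # (4, 4, 32, 32) (4, 4, 32, 32)
```

Calling `cat_shape([a1.shape, a2.shape], dim=0)` raises ValueError, mirroring the failing dim=0 case above: dim 1 differs (3 vs 1) and is not the cat dimension.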
Along distinct dim/axis
Stack
a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
print(torch.stack([a1,a2],dim=2).shape) # torch.Size([4,3,2,16,32])
a = torch.rand(32,8)
b = torch.rand(32,8)
print(torch.stack([a,b],dim=0).shape) # torch.Size([2,32,8])
print(torch.cat([a,b],dim=0).shape) # torch.Size([64,8])
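The fundamental difference between Stack and Cat is whether the inputs are kept side by side or merged. As a simple example, treat a as the scores of 32 students in 8 courses for one class, and b as the scores of 32 students in 8 courses for another class. Stack keeps the two classes as separate entries along a new dimension, [2,32,8], whereas cat merges the two classes into one whole, [64,8]. For Stack, all dimensions of the inputs must match; for Cat, the dimension being concatenated along may differ.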
b = torch.rand([30,8])
# print(torch.stack([a,b],dim=0))  # raises an error: stack requires identical shapes
print(torch.cat([a,b],dim=0).shape) # torch.Size([62, 8])
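One way to see the relationship between the two operations: stacking gives the same result as adding a new dimension of size 1 to every input and then concatenating along it. The sketch below checks this for the two-class example. The variable names are illustrative only.

```python
import torch

a = torch.rand(32, 8)  # class 1: 32 students, 8 courses
b = torch.rand(32, 8)  # class 2: same layout

# stack inserts a new dimension and places the inputs along it
stacked = torch.stack([a, b], dim=0)

# the same result built from unsqueeze + cat
via_cat = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)

same = torch.equal(stacked, via_cat)
print(stacked.shape, same)  # torch.Size([2, 32, 8]) True

# the new dimension can be inserted at other positions as well
print(torch.stack([a, b], dim=1).shape)  # torch.Size([32, 2, 8])
```

This is also why stack needs identical shapes: after unsqueeze, every original dimension is a non-cat dimension and therefore has to match.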
Split
a = torch.rand(32,8)
b = torch.rand(32,8)
c = torch.stack([a,b],dim=0)
print(c.shape) # torch.Size([2, 32, 8])
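Split by length
If the pieces have different lengths, pass a list of lengths; the lengths must add up to the size of the dimension being split. Here [1,1] splits c into 2 pieces, each of length 1.
Passing [1,2,3] would split a dimension of size 6 into 3 pieces of lengths 1, 2, and 3.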
aa,bb = c.split([1,1],dim=0)
print(aa.shape,bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
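If all pieces have the same length, pass a single integer. Here each piece has length 1, and the tensor is split into as many pieces as that length allows.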
aa,bb = c.split(1,dim=0)
print(aa.shape,bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
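With a length of 2, each piece has length 2, but c has only 2 entries along dim 0, so split returns a single tensor, which cannot be unpacked into two variables:
# aa,bb = c.split(2,dim=0)  # raises an error
ValueError: not enough values to unpack (expected 2, got 1)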
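A minimal sketch of a way to avoid the unpacking error: keep the result of split as a tuple and inspect its length instead of assuming the number of pieces. The 5-class tensor `d` is made up for illustration. It also shows that the last piece is shorter when the length does not divide the dimension evenly.

```python
import torch

c = torch.stack([torch.rand(32, 8), torch.rand(32, 8)], dim=0)  # [2, 32, 8]

# Keep the tuple; split(2) on a dimension of size 2 yields one piece.
pieces = c.split(2, dim=0)
print(len(pieces), pieces[0].shape)  # 1 torch.Size([2, 32, 8])

# Hypothetical tensor with 5 classes: 5 is not divisible by 2,
# so the last piece is shorter.
d = torch.rand(5, 32, 8)
sizes = [p.shape[0] for p in d.split(2, dim=0)]
print(sizes)  # [2, 2, 1]

# A list of lengths must sum to the size of the split dimension.
list_sizes = [p.shape[0] for p in d.split([1, 4], dim=0)]
print(list_sizes)  # [1, 4]

# cat along the same dimension undoes the split.
restored = torch.cat(d.split(2, dim=0), dim=0)
print(torch.equal(restored, d))  # True
```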
Chunk
Split by count: split into 2 chunks, each of length 2/2 = 1.
aa,bb = c.chunk(2,dim=0)
print(aa.shape,bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
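When the dimension is not evenly divisible by the requested count, chunk uses a chunk length of ceil(size / chunks), so the last chunk can be shorter and fewer chunks than requested may come back. The sketch below uses a made-up tensor `x` with 5 entries along dim 0 to show both cases.

```python
import torch

x = torch.rand(5, 32, 8)  # illustrative tensor, 5 entries along dim 0

# 3 chunks requested: chunk length is ceil(5/3) = 2, so the sizes are 2, 2, 1
three = [t.shape[0] for t in x.chunk(3, dim=0)]
print(three)  # [2, 2, 1]

# 4 chunks requested: chunk length is ceil(5/4) = 2, so only 3 chunks come back
four = x.chunk(4, dim=0)
print(len(four), [t.shape[0] for t in four])  # 3 [2, 2, 1]
```

As with split, keeping the result as a tuple and checking its length is safer than unpacking into a fixed number of variables.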