pytorch中torch.cat()函數:
功能:拼接兩個tensor。
用法:把兩個tensor A和B拼接在一起,可進行如下操作:
C = torch.cat( (A,B),0 ) #按維數0拼接(豎着拼)
C = torch.cat( (A,B),1 ) #按維數1拼接(橫着拼)
示例說明:
1)按維數0拼接
>>> import torch
>>> A=torch.ones(2,3) #2x3的張量(2行3列的矩陣)
>>> A
tensor([[ 1., 1., 1.],
[ 1., 1., 1.]])
>>> B=2*torch.ones(4,3) #4x3的張量(4行3列的矩陣)
>>> B
tensor([[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.]])
>>> C=torch.cat((A,B),0) #按維數0(行)拼接
>>> C
tensor([[ 1., 1., 1.],
[ 1., 1., 1.],
[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.]])
>>> C.size()
torch.Size([6, 3])
2)按維數1拼接
>>> D=2*torch.ones(2,4) #2x4的張量(2行4列的矩陣)
>>> C=torch.cat((A,D),1)#按維數1(列)拼接
>>> C
tensor([[ 1., 1., 1., 2., 2., 2., 2.],
[ 1., 1., 1., 2., 2., 2., 2.]])
>>> C.size()
torch.Size([2, 7])