torch.stack()與torch.cat()

torch.stack()和torch.cat()的區別在於:

torch.stack()會造成一個新的維度,在該維度上進行拼接

torch.cat()不會造成新的維度,而是在已有維度上進行拼接

 

例子1:

import torch

a = torch.tensor([[1, 2, 3]])
b = torch.stack((a, a)).size()
c = torch.cat((a, a)).size()
print(b)
print(c)

結果爲:

torch.Size([2, 1, 3])
torch.Size([2, 3])

例子2: 

import torch

a = torch.tensor([1, 2, 3])
b = torch.stack((a, a)).size()
c = torch.cat((a, a)).size()
print(b)
print(c)

結果爲:

torch.Size([2, 3])
torch.Size([6])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章