pytorch下的unsqueeze和squeeze用法

#squeeze 函數:從數組的形狀中刪除單維度條目,即把shape中爲1的維度去掉

#unsqueeze() 是squeeze()的反向操作,增加一個維度,該維度維數爲1,可以指定添加的維度。例如unsqueeze(a,1)表示在1這個維度進行添加
 

import torch

a=torch.rand(2,3,1)             
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1])

print(a.size())                 #torch.Size([2, 3, 1])
print(a.squeeze().size())       #torch.Size([2, 3])

print(a.squeeze(0).size())      #torch.Size([2, 3, 1])

print(a.squeeze(-1).size())     #torch.Size([2, 3])
print(a.size())                 #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())     #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())     #torch.Size([2, 3, 1])
print(a.squeeze(1).size())      #torch.Size([2, 3, 1])
print(a.squeeze(2).size())      #torch.Size([2, 3])
print(a.squeeze(3).size())      #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

print(a.unsqueeze().size())     #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())   #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())   #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())   #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())    #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())    #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())    #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())    #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())       #torch.Size([2, 3])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章