Pytorch 之squeeze和unsqueeze用法

Pytorch使用中常會用到torch.squeeze()和torch.unsqueeze()函數:

查找資料相關記錄如下:

參考博客:https://blog.csdn.net/qq_39709535/article/details/81841426

1. torch.squeeze(input, dim = None, out = None): 返回一個tensor,當dim不設值時,去掉輸入的tensor的所有維度爲1的維度; 當dim爲某一整數(0<=dim<input.dim())時,判斷dim維的維度是否爲1,若是則去掉,否則不變。
另外,當input是一維的時候,squeeze不變

>>> x = torch.zeros(1,1,2,1,3)
>>> x.dim()
5
>>> torch.squeeze(x).size()#去掉dim=1的維度
torch.Size([2, 3])
>>> torch.squeeze(x,0).size()  # dim=0表示第一維,且第一維的維度爲1,所以去掉
torch.Size([1, 2, 1, 3])
>>> torch.squeeze(x,3).size()
torch.Size([1, 1, 2, 3])
>>> torch.squeeze(x,2).size()  # dim=2,第三維的維度爲2!=1,所以不變
torch.Size([1, 1, 2, 1, 3])


 

2. torch.unqueeze(input, dim, out=None): 和squeeze作用相反,unsqueeze()在dim維插入一個維度爲1的維,例如原來x是n×m維的,torch.unqueeze(x,0)這返回1×n×m的tensor


 

>>> x = torch.tensor([1,2,3])#dim=1,即(3)
>>> torch.unsqueeze(x,1)#變爲(3,1)的矩陣
tensor([[ 1],
        [ 2],

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章