pytorch的通道順序和caffe一致,是(B C H W)
下面函數例子均基於一個feature map: feat, 通道順序爲(B C H W)
1.permute
feat.permute(0,2,3,1),將原來序號爲0,1,2,3的通道調整爲0,2,3,1,對於feat來說這個操作就是將通道順序變爲tensorflow風格。
2.squeeze和unsqueeze
squeeze(dim)是對維度爲1的維度dim進行降維,如[3,2,1] -> [3,2]。
unsqueeze(dim)是在dim處插入維度爲1的維度。
import torch
a = torch.Tensor([[[1],[3]],[[5],[7]],[[9],[11]]])
a1 = a.squeeze(2)
a2 = a1.unsqueeze(0)
if __name__ == '__main__':
print(a.size())
print(a1.size())
print(a2.size())
輸出爲:
torch.Size([3, 2, 1])
torch.Size([3, 2])
torch.Size([1, 3, 2])
3.expand,repeat
首先都是用於一個維度的複製擴展。expand不佔用新空間,repeat要用。
import torch
a = torch.Tensor([[[1],[3]],[[5],[7]],[[9],[11]]])
a1 = a.expand([3,2,4])
a2 = a.repeat([3,2,4])
if __name__ == '__main__':
print(a.size())
print(a1)
print(a2)
輸出爲:
torch.Size([3, 2, 1])
tensor([[[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.]],
[[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.]],
[[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.]]])
tensor([[[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.],
[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.]],
[[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.],
[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.]],
[[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.],
[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.]],
[[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.],
[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.]],
[[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.],
[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.]],
[[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.],
[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.]],
[[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.],
[ 1., 1., 1., 1.],
[ 3., 3., 3., 3.]],
[[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.],
[ 5., 5., 5., 5.],
[ 7., 7., 7., 7.]],
[[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.],
[ 9., 9., 9., 9.],
[ 11., 11., 11., 11.]]])
4.view
view是一個很有意思的函數,總的來說,就是把整個高維數組按照數字出現的順序展開成一維,然後在不改變數字順序的情況下,按照view給出的shape生成新的數組。
import torch
a = torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]],[[9,10],[11,12]]])
a1 = a.view([4,3,1])
a2 = a.view([2,6])
if __name__ == '__main__':
print(a.size())
print(a1)
print(a2)
結果爲:
torch.Size([3, 2, 2])
tensor([[[ 1.],
[ 2.],
[ 3.]],
[[ 4.],
[ 5.],
[ 6.]],
[[ 7.],
[ 8.],
[ 9.]],
[[ 10.],
[ 11.],
[ 12.]]])
tensor([[ 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12.]])
5.gather
由淺入深,先引用一個@江戶川柯壯 的例子
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
---------------------
作者:江戶川柯壯
來源:CSDN
原文:https://blog.csdn.net/edogawachia/article/details/80515038
輸出分別爲:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
dim是最難理解的一個變量。這裏我的理解是,dim是index中定義序號的維度。dim=0即序號是沿着縱向數的,橫向同一行的數字序號一樣,dim=1則反過來。
index的維度也需要注意。首先index每次是要把其他維遍歷一遍的,即index中除了dim維,其他維的形狀要和原數組保持一致。輸入是[2,3], dim=0時,index可能的維數爲[N,3],即可以遍歷數組的列多次,但是每一次必須把所有列都遍歷一遍。其次,index的維數要和輸入數組一致,即使只遍歷一次,二維輸入對應的index也要是二維的[1,3],而不能是一維的3.
再舉一個3維的例子:
import torch
a = torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]],[[9,0],[11,12]]])
index = torch.LongTensor([[[1,0]],[[0,0]],[[1,1]]])
if __name__ == '__main__':
print(a.size())
print(index.size())
print(torch.gather(a,1,index))
輸出爲:
torch.Size([3, 2, 2])
torch.Size([3, 1, 2])
tensor([[[ 3., 2.]],
[[ 5., 6.]],
[[ 11., 12.]]])
熟記除了dim維其他維度必須與原矩陣保持一致,最終的效果就一目瞭然了。