pytorch——unsqueeze與expand

torch中的unsqueeze()函數來增加一個維度,expand()函數以行或列來廣播。

# -*- encoding: utf-8 -*-
import torch

# 需求是對一個batch_size=2, seq_len=3的兩個序列進行mask的擴展,
# 擴展爲[batch_size, seq_len, 4, seq_len]
tokens = torch.tensor([[1,2, 3],[2,1,0]])
mask = tokens!=0
print(mask)
print(mask.shape)

print(mask.unsqueeze(2).shape)
print(mask.unsqueeze(2))
print(mask.unsqueeze(1).shape)
print(mask.unsqueeze(1))

multi = mask.unsqueeze(2)*mask.unsqueeze(1)
print('multi shape:',multi.shape) # [batch_size, seq, seq]
print(multi)

select = multi.unsqueeze(2)
print(select.shape) # batch, seq, 1, seq
print(select)
print(select.expand(-1,-1, 4, -1)) # expand的作用是把某個維度上爲1的擴展爲指定的個數
  • expand()在行或列上的擴展
b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

result:

b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章