使用torch中的unsqueeze()函數來增加一個維度,再使用expand()函數沿行或列方向進行廣播。
# -*- encoding: utf-8 -*-
import torch
# Goal: for a batch of two token sequences (batch_size=2, seq_len=3), build a
# non-zero mask and expand it to shape [batch_size, seq_len, 4, seq_len].
tokens = torch.tensor([[1, 2, 3], [2, 1, 0]])
mask = tokens.ne(0)  # True wherever the token value is non-zero
print(mask)
print(mask.shape)

# Insert a size-1 axis at position 2 and at position 1, then show the shape
# and the values of each resulting view.
column_view = mask.unsqueeze(2)
row_view = mask.unsqueeze(1)
for view in (column_view, row_view):
    print(view.shape)
    print(view)

# Multiplying the two views broadcasts them against each other, producing a
# pairwise mask over sequence positions.
multi = column_view * row_view
print('multi shape:', multi.shape)  # [batch_size, seq, seq]
print(multi)

# Add a size-1 axis that expand() can then stretch.
select = multi.unsqueeze(2)
print(select.shape)  # batch, seq, 1, seq
print(select)

# expand() stretches a size-1 dimension to the requested size; -1 leaves a
# dimension unchanged.
expanded = select.expand(-1, -1, 4, -1)
print(expanded)
expand()
在行或列上的擴展:形狀爲[3, 1]的張量b擴展爲[3, 3]時,每一行的值沿列方向複製;形狀爲[1, 3]的張量c擴展爲[3, 3]時,整行沿行方向複製。
b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
result:
b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])