PyTorch中数据格式变换及创建掩码mask示例
常用维度转换方法
import torch

# Demo tensor: the values 0..5 laid out as 2 rows x 3 columns.
case = torch.arange(0, 6).reshape(2, 3)
print(case, case.shape)
# tensor([[0, 1, 2],
#         [3, 4, 5]]) torch.Size([2, 3])
-
permute()
# permute(): reorder (swap) the dimensions of a tensor.
case_permute = case.permute(1, 0)
print(case_permute, case_permute.size())
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]]) torch.Size([3, 2])
-
view()
# view(): reinterpret the same underlying data with a new shape,
# keeping row-major element order (unlike permute()).
case_view = case.view(3, 2)
print(case_view, case_view.size())
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]]) torch.Size([3, 2])
# Note: this is NOT [[0, 3], ...] — contrast with permute()!
case_view = case.view(1, -1)  # -1 lets PyTorch infer the dimension
print(case_view, case_view.size())
# tensor([[0, 1, 2, 3, 4, 5]]) torch.Size([1, 6])
-
squeeze()与unsqueeze()
# squeeze(dim): remove dimension `dim` if (and only if) its size is 1;
# otherwise the tensor is returned unchanged. For a (1, 3) tensor,
# dim 0 can be dropped but dim 3 cannot.
case = torch.arange(0, 3).view(1, 3)
print(case, case.size())
# tensor([[0, 1, 2]]) torch.Size([1, 3])
case_squeeze = case.squeeze(0)
print(case_squeeze, case_squeeze.size())
# tensor([0, 1, 2]) torch.Size([3])
case_squeeze = case.squeeze(1)  # size 3 != 1, so this is a no-op
print(case_squeeze, case_squeeze.size())
# tensor([[0, 1, 2]]) torch.Size([1, 3])

# unsqueeze(dim): the inverse operation — insert a new size-1 dimension
# at position `dim`.
case = torch.arange(0, 3).view(3)
print(case, case.size())
# tensor([0, 1, 2]) torch.Size([3])
case_unsqueeze = case.unsqueeze(0)
print(case_unsqueeze, case_unsqueeze.size())
# tensor([[0, 1, 2]]) torch.Size([1, 3])
case_unsqueeze = case.unsqueeze(1)
print(case_unsqueeze, case_unsqueeze.size())
# tensor([[0],
#         [1],
#         [2]]) torch.Size([3, 1])
-
expand()
# expand(): return a new *view* of the tensor in which size-1 dimensions
# are broadcast to a larger size. The tensor may also be expanded to a
# higher rank, with new dimensions prepended. No new memory is allocated:
# the expanded dimension gets stride 0, so a size-1 dimension can be
# expanded to any size for free.
# Note: expand() never modifies `x` itself — the result must be assigned.
x = torch.Tensor([[1], [2], [3]])
print("x.size():", x.size())
y = x.expand(3, 4)
print("x.size():", x.size())
print("y.size():", y.size())
print(x)
print(y)
# x.size(): torch.Size([3, 1])
# x.size(): torch.Size([3, 1])
# y.size(): torch.Size([3, 4])
# tensor([[1.],
#         [2.],
#         [3.]])
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]])
示例:根据batch中句子长度lengths构建掩码mask
# Example: build a boolean padding mask for a batch from its sentence lengths.
# sequence_length = torch.LongTensor([10,8,6,3,7]).cuda() # batch of 5, on GPU
sequence_length = torch.LongTensor([10,8,6,3,7])  # lengths for a batch of 5
batch_size = sequence_length.size(0)         # number of sequences in the batch
max_len = sequence_length.data.max()         # length of the longest sequence
seq_range = torch.arange(0, max_len).long()  # positions 0 .. max_len-1
print(seq_range, seq_range.size())
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) torch.Size([10])

# Replicate the position row once per batch element.
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
# seq_range_expand = seq_range_expand.cuda() # move to GPU if needed
print(seq_range_expand, seq_range_expand.size())
# tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) torch.Size([5, 10])

# Replicate each sequence's length across its row (expand_as matches
# the target tensor's shape).
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
print(seq_length_expand, seq_length_expand.size())
# tensor([[10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
#         [ 8,  8,  8,  8,  8,  8,  8,  8,  8,  8],
#         [ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6],
#         [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3],
#         [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7]]) torch.Size([5, 10])

# Position i is valid for sequence j iff i < length_j.
print(seq_range_expand < seq_length_expand)
# tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
#         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True, False, False, False, False],
#         [ True,  True,  True, False, False, False, False, False, False, False],
#         [ True,  True,  True,  True,  True,  True,  True, False, False, False]])
另一种创建mask的简单方法
def generate_sent_masks(self, batch_size, max_seq_length, source_lengths):
    """Generate sentence masks for encoder hidden states.

    Positions before each sentence's true length are marked 1.0;
    padding positions stay 0.0.

    Args:
        batch_size: number of sentences b in the batch.
        max_seq_length: mask width (the max source length).
        source_lengths: iterable of true lengths, one per sentence.

    Returns:
        enc_masks (Tensor): float mask of shape (b, max_seq_length).
    """
    masks = torch.zeros(batch_size, max_seq_length, dtype=torch.float)
    for row, length in enumerate(source_lengths):
        masks[row, :length] = 1
    return masks