pytorch中數據格式變換及創建掩碼mask示例

pytorch中數據格式變換及創建掩碼mask示例

常用維度轉換方法

import torch
case = torch.arange(0, 6).view(2, 3)
print(case, case.size())
# tensor([[0, 1, 2],
#         [3, 4, 5]]) torch.Size([2, 3])
  • permute()

    '''
    交換維度
    '''
    case_permute =  case.permute(1, 0)
    print(case_permute, case_permute.size())
    # tensor([[0, 3],
    #         [1, 4],
    #         [2, 5]]) torch.Size([3, 2])
    
  • view()

    '''
    view()函數作用是將一個多行的Tensor,拼接成指定維度。
    '''
    case_view = case.view(3, 2)
    print(case_view, case_view.size())
    # tensor([[0, 1],
    #         [2, 3],
    #         [4, 5]]) torch.Size([3, 2])
    # 注意不是[[0, 3], ...],與permute()做區分!
    
    case_view = case.view(1, -1)
    print(case_view, case_view.size())
    # tensor([[0, 1, 2, 3, 4, 5]]) torch.Size([1, 6])
    
  • squeeze()與unsqueeze()

    '''
    squeeze中的參數0、1分別代表第一、第二維度,squeeze(0)表示如果第一維度值爲1,則去掉,否則不變。
    故case的維度(1,3),可去掉1成(3),但不可去掉3。
    '''
    case = torch.arange(0, 3).view(1, 3)
    print(case, case.size())
    # tensor([[0, 1, 2]]) torch.Size([1, 3])
    
    case_squeeze = case.squeeze(0)
    print(case_squeeze, case_squeeze.size())
    # tensor([0, 1, 2]) torch.Size([3])
    
    case_squeeze = case.squeeze(1)
    print(case_squeeze, case_squeeze.size())
    # tensor([[0, 1, 2]]) torch.Size([1, 3])
    
    '''
    unsqueeze()與squeeze()作用相反。參數代表的意思相同。
    '''
    case = torch.arange(0, 3).view(3)
    print(case, case.size())
    # tensor([0, 1, 2]) torch.Size([3])
    
    case_unsqueeze = case.unsqueeze(0)
    print(case_unsqueeze, case_unsqueeze.size())
    # tensor([[0, 1, 2]]) torch.Size([1, 3])
    
    case_unsqueeze = case.unsqueeze(1)
    print(case_unsqueeze, case_unsqueeze.size())
    # tensor([[0],
    #         [1],
    #         [2]]) torch.Size([3, 1])
    
  • expand()

    '''
    返回tensor的一個新視圖,單個維度擴大爲更大的尺寸。 tensor也可以擴大爲更高維,新增加的維度將附在前面。 擴大tensor不需要分配新內存,只是僅僅新建一個tensor的視圖,其中通過將stride設爲0,一維將會擴展位更高維。任何一個一維的在不分配新內存情況下可擴展爲任意的數值。
    需要注意的是:使用expand()函數的時候,x自身不會改變,因此需要將結果重新賦值。
    '''
    x = torch.Tensor([[1], 
                      [2], 
                      [3]])
    print("x.size():",x.size())
    
    y=x.expand( 3,4 )
    print("x.size():",x.size())
    print("y.size():",y.size())
    print(x)
    print(y)
    # x.size(): torch.Size([3, 1])
    # x.size(): torch.Size([3, 1])
    # y.size(): torch.Size([3, 4])
    # tensor([[1.],
    #         [2.],
    #         [3.]])
    # tensor([[1., 1., 1., 1.],
    #         [2., 2., 2., 2.],
    #         [3., 3., 3., 3.]])
    

示例:根據batch中句子長度lengths構建掩碼mask

# sequence_length = torch.LongTensor([10,8,6,3,7]).cuda()  # 假設batch_size爲5的輸入.轉換至gpu上
sequence_length = torch.LongTensor([10,8,6,3,7])  # 假設batch_size爲5的輸入
batch_size = sequence_length.size(0)                     # 獲得batch_size
max_len = sequence_length.data.max()					 # 獲得最大長度
seq_range = torch.arange(0,max_len).long()
print(seq_range, seq_range.size())
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) torch.Size([10])

seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
# seq_range_expand = seq_range_expand.cuda()  # 轉換至gpu上
print(seq_range_expand, seq_range_expand.size())
# tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) torch.Size([5, 10])

seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) # expand_as 函數
print(seq_length_expand, seq_length_expand.size())
# tensor([[10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
#         [ 8,  8,  8,  8,  8,  8,  8,  8,  8,  8],
#         [ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6],
#         [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3],
#         [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7]]) torch.Size([5, 10])

print(seq_range_expand < seq_length_expand)
# tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
#         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True, False, False, False, False],
#         [ True,  True,  True, False, False, False, False, False, False, False],
#         [ True,  True,  True,  True,  True,  True,  True, False, False, False]])

另一種創建mask的簡單方法

def generate_sent_masks(self, batch_size, max_seq_length, source_lengths):
    """ Generate sentence masks for encoder hidden states.
        returns enc_masks (Tensor): Tensor of sentence masks of shape (b, max_seq_length),where max_seq_length = max source length """
    enc_masks = torch.zeros(batch_size, max_seq_length, dtype=torch.float)
    for e_id, src_len in enumerate(source_lengths):
        enc_masks[e_id, :src_len] = 1
    return enc_masks
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章