PyTorch Learning Notes (4): Utilities

1. PackedSequence

torch.nn.utils.rnn.PackedSequence

Instances of this class should never be created by hand. They can only be instantiated by pack_padded_sequence().
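
As a minimal sketch (written against the pre-0.4 Variable API this post uses), a PackedSequence comes out of pack_padded_sequence rather than being constructed directly:

import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

# a padded batch of shape [seq_length x batch_size x input_size]
padded = Variable(torch.zeros(5, 2, 1))
# lengths must be sorted in decreasing order
packed = pack_padded_sequence(padded, [5, 3])
print(type(packed))  # <class 'torch.nn.utils.rnn.PackedSequence'>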

2. pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence()

Inputs:

input: of shape [seq_length x batch_size x input_size] or [batch_size x seq_length x input_size]; the sequences in input must be arranged in decreasing order of length (see the sorting sketch after this list).

lengths: the list of sequence lengths, sorted in decreasing order to match the sequences in input, e.g. [5, 4, 1].

batch_first: a bool; when True, input has the form [batch_size x seq_length x input_size], otherwise the other layout.
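
Since input must already be sorted by decreasing length, an unsorted batch has to be reordered before packing. A minimal sketch of that preprocessing step (the variable names here are illustrative):

import torch
from torch.autograd import Variable

lengths = [1, 5, 4]  # unsorted lengths
order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
sorted_lengths = [lengths[i] for i in order]  # [5, 4, 1]
padded = Variable(torch.zeros(3, 5, 1))  # [batch_size x seq_length x input_size]
index = Variable(torch.LongTensor(order))
sorted_padded = padded.index_select(0, index)  # reorder the batch dim to match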

Outputs:

A PackedSequence object, containing data of type Variable and batch_sizes of type list.

Each element of batch_sizes says how many rows of data form the batch at that time step.

For example:

The input is

input
Variable containing:
(0 ,.,.) = 
  1
  2
  3

(1 ,.,.) = 
  1
  0
  0
[torch.FloatTensor of size 2x3x1]
lengths = [3, 1]

To achieve the packed encoding, that is, to strip away the padding, the final output is

PackedSequence(data=Variable containing:
 1
 1
 2
 3
[torch.FloatTensor of size 4x1]
, batch_sizes=[2, 1, 1])

This shows that the first two 1s belong to the same batch (the first time step), while the last two values each form a batch of their own. In other words, we can read off from batch_sizes that the two sequences have lengths 3 and 1. Downstream modules and functions can then use batch_sizes to read the corresponding rows of data.
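
For instance, PyTorch's recurrent modules accept a PackedSequence directly; internally they use batch_sizes to decide how many rows to feed at each time step, so the padding is never computed. A minimal sketch with an assumed nn.RNN configuration:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

rnn = nn.RNN(input_size=1, hidden_size=2)
input = Variable(torch.FloatTensor([[[1], [2], [3]],
                                    [[1], [0], [0]]]))
packed = pack_padded_sequence(input, [3, 1], batch_first=True)
output, h_n = rnn(packed)  # output is again a PackedSequence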

Code walkthrough

Using the input above as an example, let's look at how this function actually packs the data.

def pack_padded_sequence(input, lengths, batch_first=False):
    # even the shortest length must be > 0
    if lengths[-1] <= 0:
        raise ValueError("length of all samples has to be greater than 0, "
                         "but found an element in 'lengths' that is <=0")
    # bring input into the shape [seq_length x batch_size x input_size];
    # in our example the transposed input has shape [3 x 2 x 1]
    if batch_first:
        input = input.transpose(0, 1)

    steps = []
    batch_sizes = []
    # iterate over the lengths from shortest to longest
    lengths_iter = reversed(lengths)
    # in our example current_length == 1
    current_length = next(lengths_iter)
    batch_size = input.size(1)
    if len(lengths) != batch_size:
        raise ValueError("lengths array has incorrect size")
    # enumerate(..., 1) makes 'step' start from 1
    for step, step_value in enumerate(input, 1):
        """
        at step 1, step_value == 1
                                 1
                                 [torch.FloatTensor of size 2x1]
        """
        # keep only the rows of the sequences that are still active
        steps.append(step_value[:batch_size])
        batch_sizes.append(batch_size)
        # check whether we have just reached the end of a short seq
        while step == current_length:
            try:
                new_length = next(lengths_iter)
            except StopIteration:
                current_length = None
                break
            # verify that lengths is a decreasing list
            if current_length > new_length:  # remember that new_length is the preceding length in the array
                raise ValueError("lengths array has to be sorted in decreasing order")
            # we have stepped past a short seq, so the batch shrinks by 1
            batch_size -= 1
            current_length = new_length
        if current_length is None:
            break
    # concatenate the per-step slices along dim 0
    return PackedSequence(torch.cat(steps), batch_sizes)
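
Tracing the loop on our example (after the transpose, input has shape [3 x 2 x 1] and lengths = [3, 1]):

# step 1: step_value is [[1], [1]], batch_size is 2 -> batch_sizes = [2];
#         step == current_length (1), so take the next length (3) and shrink batch_size to 1
# step 2: step_value is [[2], [0]], only the first row is kept -> batch_sizes = [2, 1]
# step 3: step_value is [[3], [0]], only the first row is kept -> batch_sizes = [2, 1, 1];
#         lengths_iter is exhausted, so current_length becomes None and the loop ends
# torch.cat(steps) then yields the 4x1 data tensor [1, 1, 2, 3]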

3. pad_packed_sequence

nn.utils.rnn.pad_packed_sequence()

This is the inverse of the previous function. The input is a PackedSequence object; its batch_sizes tell the function how to unpack the data back into a padded tensor.
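
A minimal sketch of the round trip, reusing the packed example from above:

from torch.nn.utils.rnn import pad_packed_sequence

unpacked, lengths = pad_packed_sequence(packed, batch_first=True)
# unpacked is the padded Variable of size 2x3x1 again, and lengths == [3, 1]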
