PyTorch Study Notes (4): Utilities


1. PackedSequence

torch.nn.utils.rnn.PackedSequence

Instances of this class cannot be created by hand; they can only be produced by pack_padded_sequence().

2. pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence()

Inputs:

input: [seq_length x batch_size x input_size] or [batch_size x seq_length x input_size]; the sequences in input must be arranged in order of decreasing length.

lengths: a list of the sequence lengths, sorted in decreasing order and corresponding to the sequences in input, e.g. [5, 4, 1].

batch_first: a bool; when True, input takes the form [batch_size x seq_length x input_size], otherwise the other form.

Output:

A PackedSequence object containing a data field of type Variable and a batch_sizes field of type list.

Each element of batch_sizes indicates how many rows of data belong to that time step.

For example, with the input:

input
Variable containing:
(0 ,.,.) = 
  1
  2
  3

(1 ,.,.) = 
  1
  0
  0
[torch.FloatTensor of size 2x3x1]
lengths = [3, 1]

To achieve the packed encoding, i.e. to strip away the padding, the final output is:

PackedSequence(data=Variable containing:
 1
 1
 2
 3
[torch.FloatTensor of size 4x1]
, batch_sizes=[2, 1, 1])

This shows that the first two 1s belong to the same time step, while the remaining two each belong to their own. In other words, from batch_sizes we can tell that the two sequences have lengths 3 and 1. Downstream modules or functions can read the corresponding rows of data according to batch_sizes.
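The packing logic above can be sketched in plain Python (no torch, just lists), using the same example of two sequences [1, 2, 3] and [1, 0, 0] with lengths = [3, 1]. Note `pack` here is an illustrative helper, not the library function:

```python
def pack(sequences, lengths):
    """Pack padded sequences: keep only the valid entries, time-step major.

    `sequences` must be sorted by decreasing length, matching `lengths`.
    """
    data = []
    batch_sizes = []
    max_len = lengths[0]
    for t in range(max_len):
        # number of sequences still active at time step t
        batch_size = sum(1 for length in lengths if length > t)
        # collect the t-th element of every still-active sequence
        for b in range(batch_size):
            data.append(sequences[b][t])
        batch_sizes.append(batch_size)
    return data, batch_sizes

data, batch_sizes = pack([[1, 2, 3], [1, 0, 0]], [3, 1])
print(data)         # [1, 1, 2, 3]
print(batch_sizes)  # [2, 1, 1]
```

This reproduces exactly the data and batch_sizes shown in the PackedSequence above.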

Code walkthrough

Taking the input above as an example, let's examine how this function actually compresses the data.

def pack_padded_sequence(input, lengths, batch_first=False):
    # check that all lengths are > 0
    if lengths[-1] <= 0:
        raise ValueError("length of all samples has to be greater than 0, "
                         "but found an element in 'lengths' that is <=0")
    # change the input into the shape of [seq_length x batch_size x input_size]
    # here input becomes shape [3, 2, 1]
    if batch_first:
        input = input.transpose(0, 1)

    steps = []
    batch_sizes = []
    # get the reversed iterator of the lengths
    lengths_iter = reversed(lengths)
    # here current_length == 1
    current_length = next(lengths_iter)
    batch_size = input.size(1)
    if len(lengths) != batch_size:
        raise ValueError("lengths array has incorrect size")
    # the 1 indicates that 'step' starts counting from 1
    for step, step_value in enumerate(input, 1):
        """
        step_value == 1
                      1
                     [torch.FloatTensor of size 2x1]
        """
        steps.append(step_value[:batch_size])
        batch_sizes.append(batch_size)
        # check whether we have reached the end of a shorter seq
        while step == current_length:
            try:
                new_length = next(lengths_iter)
            except StopIteration:
                current_length = None
                break
            # check that lengths is a decreasing list
            if current_length > new_length:  # remember that new_length is the preceding length in the array
                raise ValueError("lengths array has to be sorted in decreasing order")
            # we have stepped past a shorter seq, so the batch size decreases by 1
            batch_size -= 1
            current_length = new_length
        if current_length is None:
            break
    # concatenate the per-step tensors along dim 0
    return PackedSequence(torch.cat(steps), batch_sizes)
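To see how "downstream modules read the corresponding data according to batch_sizes", here is a minimal sketch (plain Python, no torch) of how a consumer walks the packed data: at each time step t, the next batch_sizes[t] entries of data are the inputs for the sequences still active at that step.

```python
# packed data and batch_sizes from the worked example above
data = [1, 1, 2, 3]
batch_sizes = [2, 1, 1]

offset = 0
steps = []
for bs in batch_sizes:
    # the rows for this time step: one entry per still-active sequence
    step_input = data[offset:offset + bs]
    steps.append(step_input)
    offset += bs

print(steps)  # [[1, 1], [2], [3]]
```

An RNN consuming a PackedSequence does essentially this: it runs its cell on a shrinking batch at each step, skipping the padding entirely.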

3. pad_packed_sequence

nn.utils.rnn.pad_packed_sequence()

This is the inverse of the previous function. Its input is a PackedSequence object; using its batch_sizes, the data can be unpacked back into padded form.
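The inverse operation can likewise be sketched in plain Python: rebuild the padded [batch_size x max_len] array and recover the lengths from data and batch_sizes. `unpack` here is an illustrative helper, not the library function:

```python
def unpack(data, batch_sizes, pad_value=0):
    """Inverse of packing: rebuild padded sequences and their lengths."""
    max_len = len(batch_sizes)       # longest sequence length
    num_seqs = batch_sizes[0]        # batch size at step 0 = number of sequences
    padded = [[pad_value] * max_len for _ in range(num_seqs)]
    lengths = [0] * num_seqs
    offset = 0
    for t, bs in enumerate(batch_sizes):
        # the next bs entries of data belong to time step t
        for b in range(bs):
            padded[b][t] = data[offset + b]
            lengths[b] = t + 1       # sequence b is still active at step t
        offset += bs
    return padded, lengths

padded, lengths = unpack([1, 1, 2, 3], [2, 1, 1])
print(padded)   # [[1, 2, 3], [1, 0, 0]]
print(lengths)  # [3, 1]
```

This recovers exactly the original padded input and lengths from the worked example, confirming that packing and unpacking are inverses.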
