1. PackedSequence
torch.nn.utils.rnn.PackedSequence
Instances of this class should never be created by hand; they are only instantiated by pack_padded_sequence().
2. pack_padded_sequence
torch.nn.utils.rnn.pack_padded_sequence()
Inputs:
input: [seq_length x batch_size x input_size] or [batch_size x seq_length x input_size]; the sequences in input must be sorted by decreasing length.
lengths: a list of the sequence lengths, sorted in decreasing order to match the sequences in input, e.g. [5, 4, 1].
batch_first: a bool; when True, input takes the form [batch_size x seq_length x input_size], otherwise the first form.
Output:
A PackedSequence object, containing a Variable named data and a list named batch_sizes.
Each element of batch_sizes indicates how many rows of data make up the batch at that time step.
For example, given the input
input
Variable containing:
(0 ,.,.) =
1
2
3
(1 ,.,.) =
1
0
0
[torch.FloatTensor of size 2x3x1]
lengths = [3, 1]
To obtain the packed encoding, i.e. to strip the padding, the final output is
PackedSequence(data=Variable containing:
1
1
2
3
[torch.FloatTensor of size 4x1]
, batch_sizes=[2, 1, 1])
This shows that the first two 1s form one time step's batch, while the last two values each form a batch of their own. In other words, batch_sizes tells us that the two sequences have lengths 3 and 1 respectively. Downstream modules and functions can read the corresponding rows of data according to batch_sizes.
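As a sanity check, the worked example above can be reproduced with PyTorch itself. Note that in current PyTorch versions Variable has been merged into Tensor and batch_sizes is returned as a tensor rather than a list, but the contents match the output shown above:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# batch_first layout: [batch_size x seq_length x input_size] == [2 x 3 x 1]
input = torch.tensor([[[1.], [2.], [3.]],
                      [[1.], [0.], [0.]]])
packed = pack_padded_sequence(input, lengths=[3, 1], batch_first=True)
print(packed.data)         # rows [1.], [1.], [2.], [3.]
print(packed.batch_sizes)  # tensor([2, 1, 1])
```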
Code walkthrough
Using the input above as an example, let's study how this function actually compresses the data.
def pack_padded_sequence(input, lengths, batch_first=False):
    # check that every length is > 0 (lengths is decreasing, so checking the last is enough)
    if lengths[-1] <= 0:
        raise ValueError("length of all samples has to be greater than 0, "
                         "but found an element in 'lengths' that is <= 0")
    # bring input into the shape [seq_length x batch_size x input_size]
    # for our example, input becomes shape [3, 2, 1]
    if batch_first:
        input = input.transpose(0, 1)
    steps = []
    batch_sizes = []
    # iterate over the lengths in reverse, i.e. from shortest to longest
    lengths_iter = reversed(lengths)
    # in our example current_length == 1
    current_length = next(lengths_iter)
    batch_size = input.size(1)
    if len(lengths) != batch_size:
        raise ValueError("lengths array has incorrect size")
    # enumerate(..., 1) makes 'step' start from 1
    for step, step_value in enumerate(input, 1):
        """
        at the first step, step_value ==
         1
         1
        [torch.FloatTensor of size 2x1]
        """
        steps.append(step_value[:batch_size])
        batch_sizes.append(batch_size)
        # check whether we have just stepped past the end of a short sequence
        while step == current_length:
            try:
                new_length = next(lengths_iter)
            except StopIteration:
                current_length = None
                break
            # verify that lengths is a decreasing list
            if current_length > new_length:  # new_length is the preceding length in the array
                raise ValueError("lengths array has to be sorted in decreasing order")
            # a short sequence just ended, so the batch size shrinks by 1
            batch_size -= 1
            current_length = new_length
        if current_length is None:
            break
    # concatenate the collected steps along dim 0
    return PackedSequence(torch.cat(steps), batch_sizes)
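The control flow above can be mirrored with plain Python lists, which makes the bookkeeping easy to trace. The sketch below is a simplification of the same algorithm, not PyTorch's actual implementation; `padded` is a hypothetical time-major nested list standing in for the transposed input tensor:

```python
def pack_padded(padded, lengths):
    """Pack a [seq_length x batch_size] nested list, mirroring the loop above.

    padded: list of time steps, each a list with one entry per sequence;
    lengths: decreasing list of the true sequence lengths.
    """
    data, batch_sizes = [], []
    lengths_iter = reversed(lengths)          # shortest length first
    current_length = next(lengths_iter)
    batch_size = len(lengths)
    for step, step_value in enumerate(padded, 1):
        data.extend(step_value[:batch_size])  # keep only still-active sequences
        batch_sizes.append(batch_size)
        while step == current_length:         # one or more short sequences end here
            try:
                new_length = next(lengths_iter)
            except StopIteration:
                current_length = None
                break
            batch_size -= 1                   # that sequence drops out of the batch
            current_length = new_length
        if current_length is None:
            break
    return data, batch_sizes

# The worked example: sequences [1, 2, 3] and [1], padded to length 3 and
# stored time-major, so step t holds the t-th element of every sequence.
padded = [[1, 1], [2, 0], [3, 0]]
print(pack_padded(padded, [3, 1]))  # ([1, 1, 2, 3], [2, 1, 1])
```

Tracing it by hand: at step 1 both sequences are active (batch size 2); the short sequence ends there, so steps 2 and 3 each contribute a batch of 1, exactly matching batch_sizes = [2, 1, 1].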
3. pad_packed_sequence
nn.utils.rnn.pad_packed_sequence()
This is the inverse of the previous function. Its input is a PackedSequence object; using the batch_sizes it carries, the data can be unpacked back into a padded tensor.
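A plain-Python sketch of that inverse (again a simplification for illustration, not PyTorch's code): it rebuilds the time-major padded list from data and batch_sizes, filling the positions that packing removed with a padding value:

```python
def pad_packed(data, batch_sizes, padding_value=0):
    """Inverse of packing: rebuild the [seq_length x batch_size] time steps."""
    max_batch = batch_sizes[0]  # batch_sizes is decreasing, so the first is largest
    padded, offset = [], 0
    for bs in batch_sizes:
        # each time step takes the next bs rows of data, then pads back to max_batch
        step = data[offset:offset + bs] + [padding_value] * (max_batch - bs)
        padded.append(step)
        offset += bs
    return padded

print(pad_packed([1, 1, 2, 3], [2, 1, 1]))  # [[1, 1], [2, 0], [3, 0]]
```

Running it on the packed output from the earlier example recovers the original padded, time-major input, which is exactly what pad_packed_sequence does for real tensors.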