Pytorch中RNN LSTM的input(重點理解batch_size/time_steps)

原文鏈接:Pytorch中如何理解RNN LSTM的input(重點理解seq_len/time_steps) - 阿矛布朗斯洛特的文章 - 知乎

在建立時序模型時,若使用keras,我們在Input的時候就會在shape內設置好sequence_length(後面均用seq_len表示),接着便可以在自定義的data_generator內進行個性化的使用。這個值同時也就是time_steps,它代表了RNN內部的cell的數量,有點懵的朋友可以再去看看RNN的相關內容:

CSDN-專業IT技術社區-登錄​blog.csdn.net

所以設定好這個值是很重要的事情,它和batch_size,feature_dimensions(在詞向量的時候就是embedding_size了)構成了我們Input的三大維度,無論是keras/tensorflow,亦或是Pytorch,本質上都是這樣。

牽涉到這個問題是聽說Pytorch自由度更高,最近在做實驗的時候開始嘗試用Pytorch了,寫完代碼跑通後,過了段時間才意識到,好像沒有用到seq_len這個參數,果然是Keras用多了的後遺症?(果然是博主比較蠢!)檢查了一下才發現,DataLoader生成數據的時候,默認生成爲(batch_size, 1, feature_dims)。(這裏無視了batch_size和seq_len的順序,在建立模型的時候,比如nn.LSTM有個batch_first的參數,它決定了誰前誰後,但這不是我們這裏討論的重點)。

所以我們的seq_len/time_steps被默認成了1,這是在使用Pytorch的時候容易發生的問題,由於Keras先天的接口設置在Input時就讓我們無腦設置seq_len,這反而不會成爲我們在使用Keras時發生的問題,而Pytorch沒有讓我們在哪裏設置這個參數,所以一不小心可能就忽視了。

好了,接下來就來找找問題怎麼出現的,又怎麼解決。果然問題還是出現在了DataLoader,在__getitem__(self, index)這裏,決定了我們如何取出數據,在這裏我發現我自己還是一條一條取的。

    def __getitem__(self, idx):
        return self.input[idx], self.target[idx]

完全沒有意識到Torch需要在這裏進行seq_len的修飾,接下來該怎麼解決呢,首先看看我們希望的“取數據方式”。

假如我們有id = 1,2,3,4,5,6,7,8,9,10一共10個sample。

假設我們設定seq_len是3。

那現在數據的形式應該爲1-2-3,2-3-4,3-4-5,4-5-6,5-6-7,6-7-8,7-8-9,8-9-10,9-10-0,10-0-0(最後兩個數據不完整,進行補零)的10個數據。這是我們真正有了seq_len這個參數,帶有“循環”這個概念,要放進RNN等序列模型中進行處理的數據。所以之前說seq_len被我默認弄成了1,那就是把1,2,3,4,5,6,7,8,9,10這樣形式的10個數據分別放進了模型訓練,自然在DataLoader裏取數據的size就成了(batch_size, 1, feature_dims),而我們現在取數據纔會是(batch_size, 3, feature_dims)。

假設我們設定batch_size爲2。

那我們取出第一個batch爲1-2-3,2-3-4。這個batch的size就是(2,3,feature_dims)了。我們把這個玩意兒喂進模型。

接下來第二個batch爲3-4-5,4-5-6。

第三個batch爲5-6-7,6-7-8。

第四個batch爲7-8-9,8-9-10。

第五個batch爲9-10-0,10-0-0。我們的數據一共生成了5個batch。

可以看到,num_batch = num_samples / batch_size(這裏沒有進行向上或向下取整是因爲在某些地方可以設置是否需要那些不完整的被進行補零的batch),seq_len仍然不會影響最後生成的batch的數量,只有batch_size和num_samples會對batch的數量進行影響。

可能忽略了feature_dims僅憑藉id來代表數據難以理解,那換種方式看看,假如feature_dims爲6:

data_ = [[1, 10, 11, 15, 9, 100],
         [2, 11, 12, 16, 9, 100],
         [3, 12, 13, 17, 9, 100],
         [4, 13, 14, 18, 9, 100],
         [5, 14, 15, 19, 9, 100],
         [6, 15, 16, 10, 9, 100],
         [7, 15, 16, 10, 9, 100],
         [8, 15, 16, 10, 9, 100],
         [9, 15, 16, 10, 9, 100],
         [10, 15, 16, 10, 9, 100]]

仍然設置seq_len爲3,batch_size爲2。

這時我們的第一個batch爲

tensor([[[  1.,  10.,  11.,  15.,   9., 100.],
         [  2.,  11.,  12.,  16.,   9., 100.],
         [  3.,  12.,  13.,  17.,   9., 100.]],

        [[  2.,  11.,  12.,  16.,   9., 100.],
         [  3.,  12.,  13.,  17.,   9., 100.],
         [  4.,  13.,  14.,  18.,   9., 100.]]])

這就是剛剛的1-2-3,2-3-4嘛。

而最後一個batch爲

tensor([[[  9.,  15.,  16.,  10.,   9., 100.],
         [ 10.,  15.,  16.,  10.,   9., 100.],
         [  0.,   0.,   0.,   0.,   0.,   0.]],

        [[ 10.,  15.,  16.,  10.,   9., 100.],
         [  0.,   0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.,   0.]]])

最後放上Demo,由於每個人的數據甚至loss等等都不一樣,不過大家應該能夠從Demo中得到一些如何針對自己的Project進行修改的點子。

# -*- coding: utf-8 -*-

import torch
import torch.utils.data as Data
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
###   Demo dataset

data_ = [[1, 10, 11, 15, 9, 100],
         [2, 11, 12, 16, 9, 100],
         [3, 12, 13, 17, 9, 100],
         [4, 13, 14, 18, 9, 100],
         [5, 14, 15, 19, 9, 100],
         [6, 15, 16, 10, 9, 100],
         [7, 15, 16, 10, 9, 100],
         [8, 15, 16, 10, 9, 100],
         [9, 15, 16, 10, 9, 100],
         [10, 15, 16, 10, 9, 100]]


###   Demo Dataset class

class DemoDatasetLSTM(Data.Dataset):

    """
        Support class for the loading and batching of sequences of samples

        Args:
            dataset (Tensor): Tensor containing all the samples
            sequence_length (int): length of the analyzed sequence by the LSTM
            transforms (object torchvision.transform): Pytorch's transforms used to process the data
    """

    ##  Constructor
    def __init__(self, dataset, sequence_length=1, transforms=None):
        self.dataset = dataset
        self.seq_len = sequence_length
        self.transforms = transforms

    ##  Override total dataset's length getter
    def __len__(self):
        return self.dataset.__len__()

    ##  Override single items' getter
    def __getitem__(self, idx):
        if idx + self.seq_len > self.__len__():
            if self.transforms is not None:
                item = torch.zeros(self.seq_len, self.dataset[0].__len__())
                item[:self.__len__()-idx] = self.transforms(self.dataset[idx:])
                return item, item
            else:
                item = []
                item[:self.__len__()-idx] = self.dataset[idx:]
                return item, item
        else:
            if self.transforms is not None:
                return self.transforms(self.dataset[idx:idx+self.seq_len]), self.transforms(self.dataset[idx:idx+self.seq_len])
            else:
                return self.dataset[idx:idx+self.seq_len], self.dataset[idx:idx+self.seq_len]


###   Helper for transforming the data from a list to Tensor

def listToTensor(list):
    tensor = torch.empty(list.__len__(), list[0].__len__())
    for i in range(list.__len__()):
        tensor[i, :] = torch.FloatTensor(list[i])
    return tensor

###   Dataloader instantiation

# Parameters
seq_len = 3
batch_size = 2
data_transform = transforms.Lambda(lambda x: listToTensor(x))

dataset = DemoDatasetLSTM(data_, seq_len, transforms=data_transform)
data_loader = Data.DataLoader(dataset, batch_size, shuffle=False)

for data in data_loader:
    x, _ = data
    print(x)
    print('\n')
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章