How is an LSTM's seq_length determined (PyTorch)?

Implementing an LSTM in PyTorch is straightforward: you only need to specify the input size, the hidden size, num_layers, and (for a classifier) the number of output classes. Note that seq_len is not a constructor argument at all; it is determined implicitly by the shape of the input tensor, (seq_len, batch, input_size).

 

Single-layer LSTM:

This structure unrolls the same LSTM cell over 3 time steps, i.e. seq_len = 3.
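A minimal sketch (the sizes here are arbitrary illustrative choices) showing that seq_len never appears in the nn.LSTM constructor and is read off the input tensor instead:

```python
import torch
import torch.nn as nn

# input_size=6 and hidden_size=4 are arbitrary illustrative choices
lstm = nn.LSTM(input_size=6, hidden_size=4, num_layers=1)

# seq_len=3 comes purely from the first dimension of the input:
# (seq_len, batch, input_size)
x = torch.randn(3, 2, 6)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([3, 2, 4]), one output per time step
print(h_n.shape)  # torch.Size([1, 2, 4]), num_layers * num_directions = 1
```

Feeding a tensor with a different first dimension changes seq_len with no change to the module itself.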

 

 

Two-layer LSTM:

The hidden outputs of the first layer at each of the 3 time steps are fed as the inputs to the corresponding 3 time steps of the second layer.

The initial hidden state h0 has shape (num_layers * num_directions, batch, hidden_size), which for two layers is (2 * num_directions, batch, hidden_size), and it defaults to zeros.
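As a sketch (sizes are illustrative), passing an explicit all-zero (h0, c0) of shape (num_layers * num_directions, batch, hidden_size) gives the same result as omitting it, which confirms the default zero initialization:

```python
import torch
import torch.nn as nn

num_layers, num_directions, batch, hidden_size = 2, 1, 2, 4
lstm = nn.LSTM(input_size=6, hidden_size=hidden_size, num_layers=num_layers)
x = torch.randn(3, batch, 6)  # seq_len = 3

# explicit zero initial states, shape (num_layers * num_directions, batch, hidden_size)
h0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
out_explicit, _ = lstm(x, (h0, c0))

# omitting (h0, c0) uses the same zero initialization by default
out_default, _ = lstm(x)
print(torch.equal(out_explicit, out_default))  # True
```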

 

 

For More 

 

Code

# -*- coding: utf-8 -*-
 
import torch
import torch.utils.data as Data
import torch.nn as nn
import torchvision.transforms as transforms
###   Demo dataset
 
data_ = [[1, 10, 11, 15, 9, 100],
         [2, 11, 12, 16, 9, 100],
         [3, 12, 13, 17, 9, 100],
         [4, 13, 14, 18, 9, 100],
         [5, 14, 15, 19, 9, 100],
         [6, 15, 16, 10, 9, 100],
         [7, 15, 16, 10, 9, 100],
         [8, 15, 16, 10, 9, 100],
         [9, 15, 16, 10, 9, 100],
         [10, 15, 16, 10, 9, 100]]
 
 
###   Demo Dataset class
 
class DemoDatasetLSTM(Data.Dataset):
 
    """
        Support class for the loading and batching of sequences of samples
        Args:
            dataset (Tensor): Tensor containing all the samples
            sequence_length (int): length of the analyzed sequence by the LSTM
            transforms (object torchvision.transform): Pytorch's transforms used to process the data
    """
 
    ##  Constructor
    def __init__(self, dataset, sequence_length=1, transforms=None):
        self.dataset = dataset
        self.seq_len = sequence_length
        self.transforms = transforms
 
    ##  Override total dataset's length getter
    def __len__(self):
        return len(self.dataset)
 
    ##  Override single item getter
    def __getitem__(self, idx):
        if idx + self.seq_len > len(self):
            # The window runs past the end of the data: zero-pad up to seq_len
            # so every item has the same length and can be batched
            if self.transforms is not None:
                item = torch.zeros(self.seq_len, len(self.dataset[0]))
                item[:len(self) - idx] = self.transforms(self.dataset[idx:])
                return item, item
            else:
                item = [[0] * len(self.dataset[0]) for _ in range(self.seq_len)]
                item[:len(self) - idx] = self.dataset[idx:]
                return item, item
        else:
            window = self.dataset[idx:idx + self.seq_len]
            if self.transforms is not None:
                window = self.transforms(window)
            return window, window
 
 
###   Helper for transforming the data from a list to a Tensor

def listToTensor(seq):
    # 'seq' avoids shadowing the built-in 'list'
    tensor = torch.empty(len(seq), len(seq[0]))
    for i, row in enumerate(seq):
        tensor[i, :] = torch.FloatTensor(row)
    return tensor
 
###   Dataloader instantiation
 
# Parameters
seq_len = 3
batch_size = 2
data_transform = transforms.Lambda(listToTensor)
 
dataset = DemoDatasetLSTM(data_, seq_len, transforms=data_transform)
data_loader = Data.DataLoader(dataset, batch_size, shuffle=False)
 
for data in data_loader:
    x, _ = data
    print(x)
    print('\n')
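Each batch printed by the loop above has shape (batch_size, seq_len, features) = (2, 3, 6), i.e. batch-first. A self-contained sketch (random data standing in for the real windows) of feeding such a batch to nn.LSTM with batch_first=True:

```python
import torch
import torch.nn as nn

# random stand-in for one loader batch: (batch_size=2, seq_len=3, features=6)
batch = torch.randn(2, 3, 6)

# batch_first=True makes nn.LSTM accept (batch, seq_len, input_size)
lstm = nn.LSTM(input_size=6, hidden_size=4, batch_first=True)
out, _ = lstm(batch)
print(out.shape)  # torch.Size([2, 3, 4])
```

Without batch_first=True the default layout is (seq_len, batch, input_size), so the loader's batches would have to be transposed first.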

 

 

 

Reference

1  Understanding multi-layer LSTM in RNNs: input, output, time steps, hidden units, number of layers

2  How to understand the input of RNN/LSTM/GRU in PyTorch (focus on seq_len / time_steps)

 

 

 

 

 

 
