How is an LSTM's seq_length determined (PyTorch)?

Implementing an LSTM in PyTorch is very convenient: you only need to define the input size, the hidden size, num_layers, and (for a classifier on top) the number of output classes.

Single-layer LSTM:

This structure contains 3 LSTM cells, one per time step, i.e. seq_len = 3.
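In fact, seq_len never appears in the nn.LSTM constructor at all: PyTorch infers it from the shape of the input tensor at run time. A minimal sketch (the sizes input_size=6 and hidden_size=10 are arbitrary illustration values):

import torch
import torch.nn as nn

# nn.LSTM is defined without any notion of sequence length
lstm = nn.LSTM(input_size=6, hidden_size=10, num_layers=1)

# seq_len is determined by the input tensor's first dimension,
# since the default layout is (seq_len, batch, input_size)
x = torch.randn(3, 2, 6)              # seq_len=3, batch=2, 6 features
output, (hn, cn) = lstm(x)
print(output.shape)                   # torch.Size([3, 2, 10])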


Two-layer LSTM:

The multi-dimensional hidden outputs of the first layer at its 3 time steps serve as the inputs to the 3 corresponding time steps of the second layer.

The initial hidden state h0, of shape (num_layers * num_directions, batch, hidden_size), here (2 * num_directions, batch, hidden_size), defaults to zeros.
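For reference, here is the two-layer case with the default zero initial states written out explicitly (again with arbitrary example sizes):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=6, hidden_size=10, num_layers=2)

x = torch.randn(3, 2, 6)              # seq_len=3, batch=2

# Passing (h0, c0) explicitly like this is equivalent to omitting
# them, because PyTorch zero-initializes the states by default.
# Shape: (num_layers * num_directions, batch, hidden_size)
h0 = torch.zeros(2 * 1, 2, 10)
c0 = torch.zeros(2 * 1, 2, 10)
output, (hn, cn) = lstm(x, (h0, c0))

print(output.shape)                   # torch.Size([3, 2, 10]), the second layer's outputs
print(hn.shape)                       # torch.Size([2, 2, 10])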


For more background, see the references below.

Code

# -*- coding: utf-8 -*-

import torch
import torch.utils.data as Data
import torchvision.transforms as transforms

###   Demo dataset

data_ = [[1, 10, 11, 15, 9, 100],
         [2, 11, 12, 16, 9, 100],
         [3, 12, 13, 17, 9, 100],
         [4, 13, 14, 18, 9, 100],
         [5, 14, 15, 19, 9, 100],
         [6, 15, 16, 10, 9, 100],
         [7, 15, 16, 10, 9, 100],
         [8, 15, 16, 10, 9, 100],
         [9, 15, 16, 10, 9, 100],
         [10, 15, 16, 10, 9, 100]]


###   Demo Dataset class

class DemoDatasetLSTM(Data.Dataset):

    """
        Support class for the loading and batching of sequences of samples.

        Args:
            dataset (list): list containing all the samples
            sequence_length (int): length of the sequence analyzed by the LSTM
            transforms (callable): torchvision transform used to process the data
    """

    ##  Constructor
    def __init__(self, dataset, sequence_length=1, transforms=None):
        self.dataset = dataset
        self.seq_len = sequence_length
        self.transforms = transforms

    ##  Override total dataset's length getter
    def __len__(self):
        return len(self.dataset)

    ##  Override single items' getter
    def __getitem__(self, idx):
        # Sequences that would run past the end of the dataset are
        # zero-padded up to seq_len so every item has the same shape.
        if idx + self.seq_len > len(self):
            if self.transforms is not None:
                item = torch.zeros(self.seq_len, len(self.dataset[0]))
                item[:len(self) - idx] = self.transforms(self.dataset[idx:])
                return item, item
            else:
                item = self.dataset[idx:]
                return item, item
        else:
            item = self.dataset[idx:idx + self.seq_len]
            if self.transforms is not None:
                item = self.transforms(item)
            return item, item


###   Helper for transforming the data from a list to Tensor

def listToTensor(samples):
    # 'samples' is a list of equal-length lists of numbers
    tensor = torch.empty(len(samples), len(samples[0]))
    for i, sample in enumerate(samples):
        tensor[i, :] = torch.FloatTensor(sample)
    return tensor


###   Dataloader instantiation

# Parameters
seq_len = 3
batch_size = 2
data_transform = transforms.Lambda(lambda x: listToTensor(x))

dataset = DemoDatasetLSTM(data_, seq_len, transforms=data_transform)
data_loader = Data.DataLoader(dataset, batch_size, shuffle=False)

# Each x has shape (batch_size, seq_len, feature) = (2, 3, 6)
for data in data_loader:
    x, _ = data
    print(x)
    print('\n')
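Each batch yielded above has shape (batch_size, seq_len, feature) = (2, 3, 6). To feed such batch-first tensors straight into an LSTM, pass batch_first=True; the hidden_size=10 below is again an arbitrary illustration value:

import torch.nn as nn

lstm = nn.LSTM(input_size=6, hidden_size=10, num_layers=2, batch_first=True)

for data in data_loader:
    x, _ = data
    output, (hn, cn) = lstm(x)
    print(output.shape)               # torch.Size([2, 3, 10]) for every batch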

References

1. Understanding multi-layer LSTMs in RNNs: inputs, outputs, time steps, hidden units, number of layers
2. How to understand the input of RNN/LSTM/GRU in PyTorch (with emphasis on seq_len / time_steps)
