Pytorch: dataloader的一些使用心得

Pytorch: Dataloader的一些使用心得

這篇博文不講原理,只講一些使用方法和技巧。所有提供的信息僅供參考,不要當作金科玉律。

基本程序框架

首先給出講述的時候使用的基本程序框架。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

class My_Dataset(Dataset):

    def __init__(self, list1, array2):
        self.len = len(list1)
        self.x_data = list1 # something support indexing, like a list, length = 16
        self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# padding unequal length sequences

def collate_fn(batch_data):
    return batch_data

# train dataloader & test dataloader

list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))

my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
                           batch_size = 4,
                           collate_fn = collate_fn)

從dataloader獲取數據

注意這個函數:

def __getitem__(self, index):
    return self.x_data[index], self.y_data[index]

這代表,如果你用下標索引i從dataloader中取出值,返回值將會是一個長度爲2的元組,下標爲0的是list1[i](即第i+1個字母),下標爲1的是array2[i](即一個size = (4, 5)的tensor)。暫且稱這種形式的數據爲data[i]

此時如果你運行如下指令:

for batch_data in enumerate(my_dataloader):
    # show batch_data

batch_data是一個長度爲2的元組,下標爲0的是這個batch的序號(在以上的程序裏面是0~3),下標爲1的是一個長度爲4(batch_size)的support indexing的對象,這個對象的每個元素就是對應batch中應該包含的幾個data[i],比如第0個batch的這個列表中的元素就分別是data[0],..data[3]。至於data[i]則是剛纔說的由兩項數據所構成的元組。
在這裏,下標爲1的對象是一個列表。而如果數據本身就是一個tensor的話,這裏會給一個第一維維度爲batch_size,其他維維度數對應的tensor.

此時如果你運行如下指令:

for batch_index, batch_data in enumerate(train_loader):
    # show data

這裏的batch_index對應元組的下標爲0的元素,即這個batch的序號(在以上的程序裏面是0~3);batch_data對應上面的列表(support indexing的對象)。顯然這種更細緻的處理是更常用的。

對於以上講的兩點,讀者可以直接跑一下附錄1所示的程序來獲得直觀感受。

collate_fn的使用

在從dataloader中讀取數據時,可以通過collate_fn做處理,使讀取的數據符合要求。

讓我們審視這個函數:

def collate_fn(batch_data):
    return batch_data

這裏輸入的batch_data就是上一節那個以batch_size爲長度,以對應位置的data[i]爲元素的列表。如果要取得元素之後進行特定處理,可以在這個函數裏面操作;這個函數的返回值會代替原來那個列表的位置。可以運行附錄2的代碼獲得直觀感受。

collate_fn的使用實例

在自然語言處理中,可能要把不等長的tensor padding 成等長,這個步驟可以在collate_fn裏面做。舉個例子,下面的這個函數從不等長Tensor的列表生成一個padding成等長的高維tensor.

def collate_fn(data):
    # self.data: list of tensors of different length
    # data:[x[0], x[1], ..], x[0].shape = (20, 128), x[1].shape = (30, 128)
    #                        x[2].shape = (28, 128), x[3].shape = (25, 128)
    data.sort(key=lambda data: len(data[0]), reverse=True) # 按照序列長度降序排列
    seq_len_list = [elem.shape[0] for elem in data]
    data = pad_sequence(data, batch_first=True, padding_value=0)
    seq_len_list = torch.Tensor(seq_len_list)
    return data_batch, seq_len_list
# data_batch.shape = [4, 30, 128], seq_len_list = [20, 30, 28, 25]

函數的返回值包括合併的高維tensor和每個小tensor的實際長度,方便後續處理使用。

附錄

附錄1

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

torch.manual_seed(314)

class My_Dataset(Dataset):

    def __init__(self, list1, array2):
        self.len = len(list1)
        self.x_data = list1 # something support indexing, like a list, length = 16
        self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# padding unequal length sequences

def collate_fn(batch_data):
    return batch_data

# train dataloader & test dataloader

list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))

my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
                           batch_size = 4,
                           collate_fn = collate_fn)


for batch_data in enumerate(my_dataloader):
    # show batch_data
    print("New Batch")
    print(type(batch_data), len(batch_data), batch_data[0], type(batch_data[1]))
    print(len(batch_data[1]), type(batch_data[1][0]))
    print(batch_data[1][0][0], type(batch_data[1][0][1]), batch_data[1][0][1].shape)

for batch_index, batch_data in enumerate(my_dataloader):
    # show batch_data
    print("Batch", batch_index)
    for i in range(len(batch_data)):
        print(type(batch_data[i]), len(batch_data[i]))
        print(batch_data[i][0], type(batch_data[i][1]), batch_data[i][1].shape)

附錄2

...

my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
                           batch_size = 4,
                           collate_fn = collate_fn)

for batch_index, batch_data in enumerate(my_dataloader):
    # show batch_data
    print("Batch", batch_index)
    print(batch_data)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章