pytorch使用教程-基於自定義 Dataloader中的collate_fn 函數 實現變長數據處理

問題背景

想要使用pytorch 框架中的 Dataset 和 Dataloader 類,將變長序列整合爲batch數據 (主要是對長短不一的序列進行補齊),通過自定義collate_fn函數,實現對變長數據的處理。

主要思路

Dataset 主要負責讀取單條數據,建立索引方式。
Dataloader 負責將數據聚合爲batch。

應用實例

測試環境: python 3.6 ,pytorch 1.2.0

數據路徑:
在這裏插入圖片描述
data路徑下存儲的是待存儲的數據樣本。
舉例:其中的 1.json 樣本格式爲:
在這裏插入圖片描述

定義數據集class,進行數據索引

數據集class定義代碼:

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
class time_series_dataset(Dataset):
    def __init__(self, data_root):
        """
        :param data_root:   數據集路徑
        """
        self.data_root = data_root
        file_list = os.listdir(data_root)
        file_prefix = []
        for file in file_list:
            if '.json' in file:
                file_prefix.append(file.split('.')[0])
        file_prefix = list(set(file_prefix))
        self.data = file_prefix
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        prefix = self.data[index]
        import json
        with open(self.data_root+prefix+'.json','r',encoding='utf-8') as f:
            data_dic=json.load(f)
        feature = np.array(data_dic['feature'])
        length=len(data_dic['feature'])
        feature = torch.from_numpy(feature)
        label = np.array(data_dic['label'])
        label = torch.from_numpy(label)
        sample = {'feature': feature, 'label': label, 'id': prefix,'length':length}
        return sample

數據集實例化:

dataset = time_series_dataset("./data/") # "./data/" 爲數據集文件存儲路徑

基於此數據集的實際數據格式如下:
舉例: dataset[0]

{'feature': tensor([17, 14, 16, 18, 14, 16], dtype=torch.int32),
 'label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0], dtype=torch.int32),
 'id': '2',
 'length': 6}

定義collate_fn函數,傳入Dataloader類

自定義collate_fn代碼

def collate_func(batch_dic):
    from torch.nn.utils.rnn import pad_sequence
    batch_len=len(batch_dic)
    max_seq_length=max([dic['length'] for dic in batch_dic])
    mask_batch=torch.zeros((batch_len,max_seq_length))
    fea_batch=[]
    label_batch=[]
    id_batch=[]
    for i in range(len(batch_dic)):
        dic=batch_dic[i]
        fea_batch.append(dic['feature'])
        label_batch.append(dic['label'])
        id_batch.append(dic['id'])
        mask_batch[i,:dic['length']]=1
    res={}
    res['feature']=pad_sequence(fea_batch,batch_first=True)
    res['label']=pad_sequence(label_batch,batch_first=True)
    res['id']=id_batch
    res['mask']=mask_batch
    return res

說明: mask 字段用以存儲變長序列的實際長度,補零的部分記爲0,實際序列對應位置記爲1。返回數據的格式及包含的字段,根據自己的需求進行定義。
Dataloader實例化調用代碼:

train_loader = DataLoader(dataset, batch_size=3, num_workers=1, shuffle=True,collate_fn=collate_func)

完整流程代碼

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
class time_series_dataset(Dataset):
    def __init__(self, data_root):
        """
        :param data_root:   數據集路徑
        """
        self.data_root = data_root
        file_list = os.listdir(data_root)
        file_prefix = []
        for file in file_list:
            if '.json' in file:
                file_prefix.append(file.split('.')[0])
        file_prefix = list(set(file_prefix))
        self.data = file_prefix
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        prefix = self.data[index]
        import json
        with open(self.data_root+prefix+'.json','r',encoding='utf-8') as f:
            data_dic=json.load(f)
        feature = np.array(data_dic['feature'])
        length=len(data_dic['feature'])
        feature = torch.from_numpy(feature)
        label = np.array(data_dic['label'])
        label = torch.from_numpy(label)
        sample = {'feature': feature, 'label': label, 'id': prefix,'length':length}
        return sample
def collate_func(batch_dic):
    from torch.nn.utils.rnn import pad_sequence
    batch_len=len(batch_dic)
    max_seq_length=max([dic['length'] for dic in batch_dic])
    mask_batch=torch.zeros((batch_len,max_seq_length))
    fea_batch=[]
    label_batch=[]
    id_batch=[]
    for i in range(len(batch_dic)):
        dic=batch_dic[i]
        fea_batch.append(dic['feature'])
        label_batch.append(dic['label'])
        id_batch.append(dic['id'])
        mask_batch[i,:dic['length']]=1
    res={}
    res['feature']=pad_sequence(fea_batch,batch_first=True)
    res['label']=pad_sequence(label_batch,batch_first=True)
    res['id']=id_batch
    res['mask']=mask_batch
    return res
if __name__ == "__main__":
    dataset = time_series_dataset("./data/")
    batch_size=3
    train_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True,collate_fn=collate_func)
    for batch_idx, batch in tqdm(enumerate(train_loader),total=int(len(train_loader.dataset) / batch_size) + 1):
        inputs,labels,masks,ids=batch['feature'],batch['label'],batch['mask'],batch['id']
        break

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章