Pytorch——DataSet與DataLoader

在使用 pytorch 構建深度學習相關的項目時,通常需要經過【模型結構】-【損失函數定義】-【數據設置】-【訓練代碼】-【log、驗證、可視化與 checkpoints】。其中,【數據設置】往往因爲項目/任務的不同,需要自定義合適的DataLoader(數據加載器)。

本文即將介紹 torch.utils.data 中的 Dataset 與 Dataloader 的基本用法,以 Unpaired Image-to-Image Translation 任務的非成對圖像數據的加載爲例,講解 pytorch 如何自定義數據加載器。

下面的代碼均在文件 dataset.py 中。

(一)引入必須的包

# -*- coding:utf-8 -*-

import torch.utils.data as data
import torchvision.transforms as transforms
import os
from PIL import Image
import random
import torch
import numpy as np

(二)自定義數據集 Dataset

#### 01. Create a dataset
## BaseDataset
class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return "BaseDataset"

    def initialize(self, opt):
        self.opt = opt
'''
定義一些公用的屬性/函數;一般的,torch.utils.data.Dataset 本身已經包含了很多屬性,如 __len__, __getitem__ 等。

一般我們會新增一個成員函數 name 和 initialize,分別用於:
1)name:沒有任何意義,純屬裝 B
2)在 pytorch 中,我們經常會使用到 parser,即一個能夠從命令行賦予超參數值的輔助類,我們在代碼中實例化它的一個對象爲 "opt" ,而且,諸如 opt.img_size, opt.batch_size 這樣的參數是與 data 相關的,所以我們通常會在這個函數引入 opt,並將它作爲自己一個屬性 self.opt,如此,我們就可以隨時訪問所有的超參數了。
'''

下面我們要自定義數據集 UnAlignedDataset。

首先看看我們的數據集長什麼樣:

,這是典型的 UIT 模型的數據集結構,可以知道涉及到 Dual training。每個子文件夾下都是一系列圖像,且是不對齊的。

我們解釋一下一些 opt 的參數:


opt.dataroot = '__data__/horse2zebra'
opt.mode = 'train'            # 訓練的時候是 train,測試的時候是 test,用來輔助分情況

opt.trainA = 'trainA'
opt.trainB = 'trainB'
opt.testA  = 'testA'
opt.testB  = 'testB'

opt.load_size = 288           # 讀入圖像大小
opt.crop_size = 256           # 將讀入後的圖像隨機裁剪出的 patch 的大小
opt.input_nc  = 3             # 圖像輸入的通道數:RGB-3,灰度圖-1,CMYK-4等等,一般是前兩種情況

下面我們的思路是:(1)在initialize中獲取所有圖像的路徑以確保我們可以訪問它們;(2)在initialize定義圖像數據的基本處理流水線;(3)在__getitem__中定義返回怎麼樣的數據。

 

## SelfDataset
class UnAlignedDataset(BaseDataset):
    ## 重寫 name,返回數據集的名字,一般用不到
    def name(self):
        return "UnAlignedDataset"

    ## 重寫 initialize
    '''
    這裏我們會根據傳入的 opt,獲取數據集的基本信息
    '''
    def initialize(self, opt):
        self.opt = opt                                     #-> 獲取 opt

        ## get dir 
        self.dataroot = opt.dataroot                       #-> 根據 opt 裏的 dataroot 得知數據集的位置

        ## get images                                      #-> 構建圖像子文件夾的路徑
        if opt.mode == 'train':
            dir_A = os.path.join(opt.dataroot, opt.trainA)
            dir_B = os.path.join(opt.dataroot, opt.trainB)
        elif opt.mode == 'test':
            dir_A = os.path.join(opt.dataroot, opt.testA)
            dir_B = os.path.join(opt.dataroot, opt.testB)

        A_paths = os.listdir(dir_A)
        B_paths = os.listdir(dir_B)
        self.length = min(len(A_paths), len(B_paths))      #-> 獲取圖像域 A 和圖像域 B 的所有文件的文件名;並定義數據集大小爲兩個域的大小的較小一個,構建新的屬性 self.length 存儲它

        ## get full path
        for i in range(len(A_paths)):
            A_paths[i] = os.path.join(dir_A, A_paths[i])
        for i in range(len(B_paths)):
            B_paths[i] = os.path.join(dir_B, B_paths[i])   #-> 爲了方便調用,先構建每張圖像的完整路徑(這裏用相對路徑)
        self.A_paths = A_paths 
        self.B_paths = B_paths                             #-> 最後,爲了在其他成員函數中可以直接訪問,我們構建新的屬性來存儲它們

        self.input_nc = self.opt.input_nc                  #-> 當然,對於一些重要的屬性,我們可以從 opt. 中單獨取出,下次用的時候就不需要經過 self.opt.xxx 調用,當然你也可以這麼做,只不過不優雅

        ## define transform
        transforms_list = [transforms.ToTensor(),                  #-> 從numpy到torch.tensor
                           transforms.Normalize((0.5, 0.5, 0.5),   
                                                (0.5, 0.5, 0.5))]  #-> 歸一化到 -1.0~+1.0
        self.transform = transforms.Compose(transforms_list)
        #-> 定義數據處理的過程,注意,經過 torch.utils.data.Dataset 讀入的圖像就已經將像素值轉換爲浮點數,範圍在 0~1.0 之間了,類型是 numpy 數組

    ## Dataset 類的核函數,用 len(dataset_object) 調用,返回數據集的大小
    #-> Dataset 的大小與 DataLoader 的 batch_size 共同決定了一個 epoch 中 迭代次數的多少。即:length_of_dataset // batch_size
    def __len__(self):
        return self.length

    ## 這個核函數是 dataset 被調用時自己內部調用的,每次 dataset 用 next 獲取下一個 batch 的數據的時候,內部會用連續的 batch_size 個索引來取值,並將最後的 batch_size 個結果在第〇個維度拼接在一起。
    '''
    舉個栗子,在圖像中,網絡的輸入一般是:(B, C, H, W);在視頻中,輸入一般是:(B, C, T, H, W)
    而在 __getitem__ 中,我們通過定義它,讓數據返回的數據是:(C, H, W)或者(C, T, H, W)的形式
    '''
    def __getitem__(self, index):
        #-> 首先我們獲取圖像路徑,注意由於我們的任務需要兩個圖像域的圖像
        #-> 我們根據索引對應數據大小的模來定位
        A_pth = self.A_paths[index % self.length]
        B_pth = self.B_paths[index % self.length]    

        #-> 讀入圖像
        x_img = Image.open(A_pth).convert('RGB')                                          #-> 讀入圖像
        x_img = x_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)     #-> 雙線性插值放縮到我們指定的大小(256x256)
        x = self.transform(x_img)                                                         #-> 數據預處理  

        y_img = Image.open(B_pth).convert('RGB')
        y_img = y_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
        y = self.transform(y_img)

        ## random crop 隨機裁剪
        h, w = x.size(1), x.size(2)
        h_offset = random.randint(0, max(0, h - self.opt.crop_size - 1))
        w_offset = random.randint(0, max(0, w - self.opt.crop_size - 1))
        x = x[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
        y = y[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]

        ## expand to 4-dim tensor
        if self.opt.input_nc == 1:
            # RGB to gray
            tmp_x = x[0, ...] * 0.299 + x[1, ...] * 0.587 + x[2, ...] * 0.114
            x = tmp_x.unsqueeze(0)  # (H,W) -> (C=1,H,W)
            tmp_y = y[0, ...] * 0.299 + y[1, ...] * 0.587 + y[2, ...] * 0.114
            x = tmp_y.unsqueeze(0)  # (H,W) -> (C=1,H,W)

        return {'A': x, 'B':y, 'A_pth': A_pth, 'B_pth': B_pth} 
        '''
        返回什麼樣的數據是我們自定義的,後面我們會看到,我們怎麼使用它:

        for i, data in enumerate(dataset):
            real_x = data['A']
            real_y = data['B']
            ...
        
        可以發現,DataLoader 只負責返回 batch 的數據(數據分不同部分時,各個部分單獨作爲 batch),數據的具體內容自定義的

        '''

好了,我們可以發現,DataSet定義的是如何對要返回的單個數據做處理(像素值歸一化、圖像裁剪、顏色空間等,即所有一切我們在“數字圖像處理”上學到的圖像處理的技術都可以應用);我們發現,transform中有些是可以直接使用的;如果沒有,可以自定義transform的處理函數,也可以像上面RGB轉Gray那樣直接寫在__getitem__中

(三)自定義數據加載器

前面說,DataSet定義的是返回的是單個數據,那麼形成batch的任務、快速加載(分線程)的任務、每個epoch後shuffle(洗牌)數據集的任務等等,都是由DataLoader來完成的。

首先,我們定義一個基本的DataLoader,主要也是爲了引入 opt 所以新增成員函數 initialize 。

#### 0.2 Create a Dataloader
## BaseDataLoader
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None

下面我們新定義 UnAlignedDataLoader 的數據加載器。

## Dataloader for self data
class UnAlignedDataLoader(BaseDataLoader):
    def name(self):
        return "UnAlignedDataLoader"

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)  # get the copy of opt->self.opt

        # add dataset and nitialize it
        self.dataset = UnAlignedDataset()
        self.dataset.initialize(opt)          # 因爲 initialize 不是 torch.utils.data.Dataset 的核函數,所以我們需要手動調用它,纔算完整初始化

        # define a data loader
        self.dataloader = data.DataLoader(    # 調用 torch.utils.data.DataLoader,
            self.dataset,
            batch_size=opt.batch_size,        # batch 的大小
            shuffle=True,                     # 每個 epoch 後是否洗牌
            num_workers=int(opt.n_threads)    # 使用多少個進程加載數據
        )

    def load_data(self):                      # 返回整個數據加載器本身!!!非常重要
        return self

    def __len__(self):                        # 返回數據集的大小
        return len(self.dataset)

    def __iter__(self):
        for _, d in enumerate(self.dataloader):
            yield d                           # 核函數,用於每次以 batch 遍歷整個數據集,即一個epoch

現在我們可以發現了,其實,許多都是套路!我們需要自定義的最主要的就是 UnAlignedDataset 中,在 initialize 中獲取所有數據的路徑;在 __getitem__ 中讀入數據,並作自定義的處理(放縮、裁剪、像素值歸一化等等),這些處理可以是transform中已有的,也可以是自定義的。

此外的其他三個類,結構與內容都基本不需要怎麼改。

(四)測試

最後就是測試啦~

#### Test data loader
from config import parser
opt = parser.parse_args() ##-> 這是我自定義的,大家需要自己定義,結構大致如下:
'''
# config.py

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
...

'''

data_loader = UnAlignedDataLoader()
data_loader.initialize(opt)

data_set = data_loader.load_data()

for i, data in enumerate(data_set):
    print(i, data['A'].size(), data['B'].size())

輸出如左圖所示。 

至此,pytorch 自定義簡單的數據加載器遍歷數據集的做法介紹到此結束,如有疏漏/錯誤,敬請指出!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章