在使用 pytorch 構建深度學習相關的項目時,通常需要經過【模型結構】-【損失函數定義】-【數據設置】-【訓練代碼】-【log、驗證、可視化與 checkpoints】。其中,【數據設置】往往因爲項目/任務的不同,需要自定義合適的DataLoader(數據加載器)。
本文即將介紹 torch.utils.data 中的 Dataset 與 Dataloader 的基本用法,以 Unpaired Image-to-Image Translation 任務的非成對圖像數據的加載爲例,講解 pytorch 如何自定義數據加載器。
下面的代碼均在文件 dataset.py 中。
(一)引入必須的包
# -*- coding:utf-8 -*-
import torch.utils.data as data
import torchvision.transforms as transforms
import os
from PIL import Image
import random
import torch
import numpy as np
(二)自定義數據集 Dataset
#### 01. Create a dataset
## BaseDataset
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return "BaseDataset"
def initialize(self, opt):
self.opt = opt
'''
定義一些公用的屬性/函數;一般的,torch.utils.data.Dataset 本身已經包含了很多屬性,如 __len__, __getitem__ 等。
一般我們會新增一個成員函數 name 和 initialize,分別用於:
1)name:沒有任何意義,純屬裝 B
2)在 pytorch 中,我們經常會使用到 parser,即一個能夠從命令行賦予超參數值的輔助類,我們在代碼中實例化它的一個對象爲 "opt" ,而且,諸如 opt.img_size, opt.batch_size 這樣的參數是與 data 相關的,所以我們通常會在這個函數引入 opt,並將它作爲自己一個屬性 self.opt,如此,我們就可以隨時訪問所有的超參數了。
'''
下面我們要自定義數據集 UnAlignedDataset。
首先看看我們的數據集長什麼樣:
,這是典型的 UIT 模型的數據集結構,可以知道涉及到 Dual training。每個子文件夾下都是一系列圖像,且是不對齊的。
我們解釋一下一些 opt 的參數:
opt.dataroot = '__data__/horse2zebra'
opt.mode = 'train' # 訓練的時候是 train,測試的時候是 test,用來輔助分情況
opt.trainA = 'trainA'
opt.trainB = 'trainB'
opt.testA = 'testA'
opt.testB = 'testB'
opt.load_size = 288 # 讀入圖像大小
opt.crop_size = 256 # 將讀入後的圖像隨機裁剪出的 patch 的大小
opt.input_nc = 3 # 圖像輸入的通道數:RGB-3,灰度圖-1,CMYK-4等等,一般是前兩種情況
下面我們的思路是:(1)在initialize中獲取所有圖像的路徑以確保我們可以訪問它們;(2)在initialize定義圖像數據的基本處理流水線;(3)在__getitem__中定義返回怎麼樣的數據。
## SelfDataset
class UnAlignedDataset(BaseDataset):
## 重寫 name,返回數據集的名字,一般用不到
def name(self):
return "UnAlignedDataset"
## 重寫 initialize
'''
這裏我們會根據傳入的 opt,獲取數據集的基本信息
'''
def initialize(self, opt):
self.opt = opt #-> 獲取 opt
## get dir
self.dataroot = opt.dataroot #-> 根據 opt 裏的 dataroot 得知數據集的位置
## get images #-> 構建圖像子文件夾的路徑
if opt.mode == 'train':
dir_A = os.path.join(opt.dataroot, opt.trainA)
dir_B = os.path.join(opt.dataroot, opt.trainB)
elif opt.mode == 'test':
dir_A = os.path.join(opt.dataroot, opt.testA)
dir_B = os.path.join(opt.dataroot, opt.testB)
A_paths = os.listdir(dir_A)
B_paths = os.listdir(dir_B)
self.length = min(len(A_paths), len(B_paths)) #-> 獲取圖像域 A 和圖像域 B 的所有文件的文件名;並定義數據集大小爲兩個域的大小的較小一個,構建新的屬性 self.length 存儲它
## get full path
for i in range(len(A_paths)):
A_paths[i] = os.path.join(dir_A, A_paths[i])
for i in range(len(B_paths)):
B_paths[i] = os.path.join(dir_B, B_paths[i]) #-> 爲了方便調用,先構建每張圖像的完整路徑(這裏用相對路徑)
self.A_paths = A_paths
self.B_paths = B_paths #-> 最後,爲了在其他成員函數中可以直接訪問,我們構建新的屬性來存儲它們
self.input_nc = self.opt.input_nc #-> 當然,對於一些重要的屬性,我們可以從 opt. 中單獨取出,下次用的時候就不需要經過 self.opt.xxx 調用,當然你也可以這麼做,只不過不優雅
## define transform
transforms_list = [transforms.ToTensor(), #-> 從numpy到torch.tensor
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))] #-> 歸一化到 -1.0~+1.0
self.transform = transforms.Compose(transforms_list)
#-> 定義數據處理的過程,注意,經過 torch.utils.data.Dataset 讀入的圖像就已經將像素值轉換爲浮點數,範圍在 0~1.0 之間了,類型是 numpy 數組
## Dataset 類的核函數,用 len(dataset_object) 調用,返回數據集的大小
#-> Dataset 的大小與 DataLoader 的 batch_size 共同決定了一個 epoch 中 迭代次數的多少。即:length_of_dataset // batch_size
def __len__(self):
return self.length
## 這個核函數是 dataset 被調用時自己內部調用的,每次 dataset 用 next 獲取下一個 batch 的數據的時候,內部會用連續的 batch_size 個索引來取值,並將最後的 batch_size 個結果在第〇個維度拼接在一起。
'''
舉個栗子,在圖像中,網絡的輸入一般是:(B, C, H, W);在視頻中,輸入一般是:(B, C, T, H, W)
而在 __getitem__ 中,我們通過定義它,讓數據返回的數據是:(C, H, W)或者(C, T, H, W)的形式
'''
def __getitem__(self, index):
#-> 首先我們獲取圖像路徑,注意由於我們的任務需要兩個圖像域的圖像
#-> 我們根據索引對應數據大小的模來定位
A_pth = self.A_paths[index % self.length]
B_pth = self.B_paths[index % self.length]
#-> 讀入圖像
x_img = Image.open(A_pth).convert('RGB') #-> 讀入圖像
x_img = x_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC) #-> 雙線性插值放縮到我們指定的大小(256x256)
x = self.transform(x_img) #-> 數據預處理
y_img = Image.open(B_pth).convert('RGB')
y_img = y_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
y = self.transform(y_img)
## random crop 隨機裁剪
h, w = x.size(1), x.size(2)
h_offset = random.randint(0, max(0, h - self.opt.crop_size - 1))
w_offset = random.randint(0, max(0, w - self.opt.crop_size - 1))
x = x[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
y = y[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
## expand to 4-dim tensor
if self.opt.input_nc == 1:
# RGB to gray
tmp_x = x[0, ...] * 0.299 + x[1, ...] * 0.587 + x[2, ...] * 0.114
x = tmp_x.unsqueeze(0) # (H,W) -> (C=1,H,W)
tmp_y = y[0, ...] * 0.299 + y[1, ...] * 0.587 + y[2, ...] * 0.114
x = tmp_y.unsqueeze(0) # (H,W) -> (C=1,H,W)
return {'A': x, 'B':y, 'A_pth': A_pth, 'B_pth': B_pth}
'''
返回什麼樣的數據是我們自定義的,後面我們會看到,我們怎麼使用它:
for i, data in enumerate(dataset):
real_x = data['A']
real_y = data['B']
...
可以發現,DataLoader 只負責返回 batch 的數據(數據分不同部分時,各個部分單獨作爲 batch),數據的具體內容自定義的
'''
好了,我們可以發現,DataSet定義的是如何對要返回的單個數據做處理(像素值歸一化、圖像裁剪、顏色空間等,即所有一切我們在“數字圖像處理”上學到的圖像處理的技術都可以應用);我們發現,transform中有些是可以直接使用的;如果沒有,可以自定義transform的處理函數,也可以像上面RGB轉Gray那樣直接寫在__getitem__中!
(三)自定義數據加載器
前面說,DataSet定義的是返回的是單個數據,那麼形成batch的任務、快速加載(分線程)的任務、每個epoch後shuffle(洗牌)數據集的任務等等,都是由DataLoader來完成的。
首先,我們定義一個基本的DataLoader,主要也是爲了引入 opt 所以新增成員函數 initialize 。
#### 0.2 Create a Dataloader
## BaseDataLoader
class BaseDataLoader():
def __init__(self):
pass
def initialize(self, opt):
self.opt = opt
def load_data(self):
return None
下面我們新定義 UnAlignedDataLoader 的數據加載器。
## Dataloader for self data
class UnAlignedDataLoader(BaseDataLoader):
def name(self):
return "UnAlignedDataLoader"
def initialize(self, opt):
BaseDataLoader.initialize(self, opt) # get the copy of opt->self.opt
# add dataset and nitialize it
self.dataset = UnAlignedDataset()
self.dataset.initialize(opt) # 因爲 initialize 不是 torch.utils.data.Dataset 的核函數,所以我們需要手動調用它,纔算完整初始化
# define a data loader
self.dataloader = data.DataLoader( # 調用 torch.utils.data.DataLoader,
self.dataset,
batch_size=opt.batch_size, # batch 的大小
shuffle=True, # 每個 epoch 後是否洗牌
num_workers=int(opt.n_threads) # 使用多少個進程加載數據
)
def load_data(self): # 返回整個數據加載器本身!!!非常重要
return self
def __len__(self): # 返回數據集的大小
return len(self.dataset)
def __iter__(self):
for _, d in enumerate(self.dataloader):
yield d # 核函數,用於每次以 batch 遍歷整個數據集,即一個epoch
現在我們可以發現了,其實,許多都是套路!我們需要自定義的最主要的就是 UnAlignedDataset 中,在 initialize 中獲取所有數據的路徑;在 __getitem__ 中讀入數據,並作自定義的處理(放縮、裁剪、像素值歸一化等等),這些處理可以是transform中已有的,也可以是自定義的。
此外的其他三個類,結構與內容都基本不需要怎麼改。
(四)測試
最後就是測試啦~
#### Test data loader
from config import parser
opt = parser.parse_args() ##-> 這是我自定義的,大家需要自己定義,結構大致如下:
'''
# config.py
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
...
'''
data_loader = UnAlignedDataLoader()
data_loader.initialize(opt)
data_set = data_loader.load_data()
for i, data in enumerate(data_set):
print(i, data['A'].size(), data['B'].size())
輸出如左圖所示。
至此,pytorch 自定義簡單的數據加載器遍歷數據集的做法介紹到此結束,如有疏漏/錯誤,敬請指出!