Preface
As a popular deep learning framework, PyTorch looks set to overtake TensorFlow in popularity. According to earlier surveys, TensorFlow still dominates in industry, but PyTorch has become the framework of choice at top conferences in computer vision and NLP.
In this article we focus on PyTorch's custom data-reading pipeline: the Dataset template, related tricks, and how to optimize the reading pipeline. We start from PyTorch's data object class, Dataset, which lives under torch.utils.data:
from torch.utils.data import Dataset
This article covers the Dataset object from several angles: the basic template, the torchvision transforms module, using pandas to help with reading, torch's built-in data-splitting utilities, and DataLoader.
The basic Dataset template
PyTorch provides a standardized code module for custom data reading. As a reading framework, we will call it the basic template here. Its structure is as follows:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff

    def __getitem__(self, index):
        # stuff
        return (img, label)

    def __len__(self):
        # return the number of examples
        return count
Based on this standardized template, we only need to fill the reading logic for our own task into the three methods __init__(), __getitem__() and __len__(). All three are required for data reading in the PyTorch paradigm and for the DataLoader that follows. Specifically:
__init__() initializes the reading logic, for example loading a csv file containing labels and image paths, or defining a composition of transforms.
__getitem__() returns one sample and its label; it is the method the DataLoader later calls.
__len__() returns the number of samples.
Now let us fill a few lines into this framework to form a simple numeric example, a dataset of the numbers 1 to 100:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 101))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

if __name__ == '__main__':
    dataset = CustomDataset()
    print(len(dataset))
    print(dataset[50])
    print(dataset[1:100])
Adding torchvision.transforms
Next, let us look at how to read data and embed torchvision's transforms into the reading process. torchvision is a companion library to torch for data, models, and image augmentation. It mainly includes the datasets module of built-in datasets, the models module of classic architectures, the transforms module for image augmentation, and a utils module. When reading data with torch, the transforms module is typically used alongside it to process and augment the data.
With transforms added, the reading template can be rewritten as:
from torch.utils.data import Dataset
from torchvision import transforms as T

class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # compose the transform methods
        self.transform = T.Compose([T.CenterCrop(100),
                                    T.ToTensor()])

    def __getitem__(self, index):
        # stuff
        ...
        data = ...  # some data read from a file or image
        # execute the transform
        data = self.transform(data)
        return (data, label)

    def __len__(self):
        # return the number of examples
        return count

if __name__ == '__main__':
    # Call the dataset
    custom_dataset = CustomDataset(...)
As you can see, we use the Compose method to chain several processing steps into a single data transform, usually defined in __init__(). Let us illustrate with the cats-vs-dogs image dataset.
The data-reading class can be defined as follows:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

class DogCat(Dataset):
    def __init__(self, root, transforms=None, train=True, val=False):
        """
        Get image paths and set up transforms.
        """
        self.val = val
        imgs = [os.path.join(root, img) for img in os.listdir(root)]
        # train: Cats_Dogs/trainset/cat.1.jpg
        # val:   Cats_Dogs/valset/cat.10004.jpg
        imgs = sorted(imgs, key=lambda x: x.split('.')[-2])
        self.imgs = imgs
        if transforms is None:
            # normalize with ImageNet statistics
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
            # trainset and valset use different transforms:
            # the trainset needs data augmentation, the valset does not
            if self.val:
                # valset
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                # trainset
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """
        Return one sample and its label.
        """
        img_path = self.imgs[index]
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label

    def __len__(self):
        """
        Return the number of images.
        """
        return len(self.imgs)

if __name__ == "__main__":
    train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)
    print(len(train_dataset))
    print(train_dataset[0])
Because this dataset already comes split into a training set and a validation set, the reading logic and the transforms have to distinguish between the two.
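The labeling rule in __getitem__() above relies purely on the filename convention. The sketch below isolates it as a standalone helper; label_from_path is a hypothetical name introduced here for illustration, and os.path.basename is used so the rule also works with Windows path separators.

```python
import os

def label_from_path(img_path):
    # filename convention assumed above: dog.*.jpg -> 1, cat.*.jpg -> 0
    return 1 if 'dog' in os.path.basename(img_path) else 0

print(label_from_path('Cats_Dogs/trainset/dog.3.jpg'))    # 1
print(label_from_path('Cats_Dogs/valset/cat.10004.jpg'))  # 0
```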
Using it with pandas
In many cases the image paths and labels are given in a csv file, one row per image, rather than encoded in a folder structure.
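As a concrete illustration of the csv layout assumed here (the paths and labels below are made up), the file has no header row, image paths in the first column, and integer labels in the second:

```python
import io
import pandas as pd

# hypothetical labels.csv contents: no header, col 0 = path, col 1 = label
csv_text = "images/cat.1.jpg,0\nimages/dog.3.jpg,1\n"

df = pd.read_csv(io.StringIO(csv_text), header=None)
print(df.shape)            # (2, 2)
print(df.iloc[:, 1].tolist())  # [0, 1]
```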
In that case, the __init__() method of the reading pipeline uses pandas to pull the image paths and labels from the csv file into the dataset. The reading template can be rewritten as:
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to the csv file
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column contains the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get the image path from the pandas df
        single_image_name = self.image_arr[index]
        # Open the image
        img_as_img = Image.open(single_image_name)
        # Transform the image to a tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get the label of the image
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call the dataset
    dataset = CustomDatasetFromCSV('./labels.csv')
Here is a fuller example using an mnist_labels.csv file, whose third column holds an operation indicator:
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to the csv file
        """
        # Transforms
        self.to_tensor = T.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column contains the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get the image path from the pandas df
        single_image_name = self.image_arr[index]
        # Open the image
        img_as_img = Image.open(single_image_name)
        # Check whether an operation is flagged for this row
        some_operation = self.operation_arr[index]
        if some_operation:
            # Do some operation on the image
            # ...
            pass
        # Transform the image to a tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get the label of the image
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    dataset = CustomDatasetFromCSV('./mnist_labels.csv')
    print(len(dataset))
    print(dataset[5])
Train/validation split
Generally, for stable model training, we need to split the data into a training set and a validation set. torch.utils.data provides the random_split function for this, and the resulting splits can be fed directly to the subsequent DataLoader.
Take the Kaggle flower-photos dataset as an example:
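Before the image example, here is a minimal sketch of random_split on a tiny synthetic dataset (the 10-sample TensorDataset and the seed 42 are illustrative choices, not from the original article); fixing the generator seed makes the split reproducible:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# 10 synthetic samples, split 7/3 with a fixed seed for reproducibility
dataset = TensorDataset(torch.arange(10))
trainset, valset = random_split(dataset, [7, 3],
                                generator=torch.Generator().manual_seed(42))
print(len(trainset), len(valset))  # 7 3
```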
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms as T

transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
])

dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)

n_train = int(len(dataset) * 0.7)
trainset, valset = random_split(dataset, [n_train, len(dataset) - n_train])

trainloader = DataLoader(dataset=trainset, batch_size=32,
                         shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img.shape, label)

valloader = DataLoader(dataset=valset, batch_size=32,
                       shuffle=True, num_workers=1)
for i, (img, label) in enumerate(valloader):
    img, label = img.numpy(), label.numpy()
    print(img.shape, label)
Here we use the ImageFolder module, which reads a directory with one sub-folder per class label directly.
Using DataLoader
Once the dataset class is written, we still need a DataLoader to feed the data to the model batch by batch; the split example in the previous section already used it. Essentially, the DataLoader repeatedly calls __getitem__() and returns data and labels in batches. It is used as follows:
from torch.utils.data import DataLoader
from torchvision import transforms as T

if __name__ == "__main__":
    # Define transforms
    transformations = T.Compose([T.ToTensor()])
    # Define the custom dataset (class defined above)
    dataset = CustomDatasetFromCSV('./labels.csv')
    # Define the data loader
    data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
    for images, labels in data_loader:
        # Feed the data to the model
        pass
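Putting the pieces together, here is a self-contained toy run of the whole pipeline on synthetic tensors (ToyDataset and its random data are made up for illustration), so it works without any image files on disk:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, n=100):
        self.x = torch.randn(n, 3)            # 100 fake 3-feature samples
        self.y = torch.randint(0, 2, (n,))    # binary labels

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

loader = DataLoader(ToyDataset(), batch_size=10, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([10, 3]) torch.Size([10])
```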
These are the main methods and the overall flow of PyTorch's data-reading pipeline. The basic Dataset-based framework stays the same; the details can be customized to fit your task.
This article was first published on the WeChat public account Machine Learning Lab (机器学习实验室).