PyTorch數據Pipeline標準化代碼模板

前言

PyTorch作爲一款流行深度學習框架其熱度大有超越TensorFlow的感覺。根據此前的統計,目前TensorFlow雖然仍然佔據着工業界,但PyTorch在視覺和NLP領域的頂級會議上已呈一統之勢。

這篇文章筆者將和大家聚焦於PyTorch的自定義數據讀取pipeline模板和相關trciks以及如何優化數據讀取的pipeline等。我們從PyTorch的數據對象類Dataset開始。Dataset在PyTorch中的模塊位於utils.data下。

from torch.utils.data import Dataset

本文將圍繞Dataset對象分別從原始模板、torchvision的transforms模塊、使用pandas來輔助讀取、torch內置數據劃分功能和DataLoader來展開闡述。

Dataset原始模板

PyTorch官方爲我們提供了自定義數據讀取的標準化代碼代碼模塊,作爲一個讀取框架,我們這裏稱之爲原始模板。其代碼結構如下:

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
        
    def __len__(self):
        # return examples size
        return count

根據這個標準化的代碼模板,我們只需要根據自己的數據讀取任務,分別往__init__()、__getitem__()和__len__()三個方法裏添加讀取邏輯即可。作爲PyTorch範式下的數據讀取以及爲了後續的data loader,三個方法缺一不可。其中:

  • __init__()函數用於初始化數據讀取邏輯,比如讀取包含標籤和圖片地址的csv文件、定義transform組合等。

  • __getitem__()函數用來返回數據和標籤。目的上是爲了能夠被後續的dataloader所調用。

  • __len__()函數則用於返回樣本數量。

現在我們往這個框架裏填幾行代碼來形成一個簡單的數字案例。創建一個從1到100的數字例子:

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 101))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]
        
if __name__ == '__main__':
    dataset = CustomDataset()
    print(len(dataset))
    print(dataset[50])
    print(dataset[1:100])

添加torchvision.transforms

然後我們來看如何從內存中讀取數據以及如何在讀取過程中嵌入torchvision中的transforms功能。torchvision是一個獨立於torch的關於數據、模型和一些圖像增強操作的輔助庫。主要包括datasets默認數據集模塊、models經典模型模塊、transforms圖像增強模塊以及utils模塊等。在使用torch讀取數據的時候,一般會搭配上transforms模塊對數據進行一些處理和增強工作。

添加了tranforms之後的讀取模塊可以改寫爲:

from torch.utils.data import Dataset
from torchvision import transforms as T

class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # compose the transforms methods
        self.transform = T.Compose([T.CenterCrop(100),
                                T.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        # execute the transform
        data = self.transform(data)
        return (img, label)
        
    def __len__(self):
        # return examples size
        return count
        
if __name__ == '__main__':
    # Call the dataset
    custom_dataset = CustomDataset(...)

可以看到,我們使用了Compose方法來把各種數據處理方法聚合到一起進行定義數據轉換方法。通常作爲初始化方法放在__init__()函數下。我們以貓狗圖像數據爲例進行說明。

定義數據讀取方法如下:

class DogCat(Dataset):    
    def __init__(self, root, transforms=None, train=True, val=False):
        """
        get images and execute transforms.
        """
        self.val = val
        imgs = [os.path.join(root, img) for img in os.listdir(root)]
        # train: Cats_Dogs/trainset/cat.1.jpg
        # val: Cats_Dogs/valset/cat.10004.jpg
        imgs = sorted(imgs, key=lambda x: x.split('.')[-2])
        self.imgs = imgs         
        if transforms is None:
            # normalize      
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
            # trainset and valset have different data transform 
            # trainset need data augmentation but valset don't.
            # valset

            if self.val:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            # trainset
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
                       
    def __getitem__(self, index):
        """
        return data and label
        """
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label
  
    def __len__(self):
        """
        return images size.
        """
        return len(self.imgs)

if __name__ == "__main__":
    train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)
    print(len(train_dataset))
    print(train_dataset[0])

因爲這個數據集已經分好了訓練集和驗證集,所以在讀取和transforms的時候需要進行區分。運行示例如下:

與pandas一起使用

很多時候數據的目錄地址和標籤都是通過csv文件給出的。如下所示:

此時在數據讀取的pipeline中我們需要在__init__()方法中利用pandas把csv文件中包含的圖片地址和標籤融合進去。相應的數據讀取pipeline模板可以改寫爲:

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    dataset =  CustomDatasetFromCSV('./labels.csv')

以mnist_label.csv文件爲示例:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import pandas as pd

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file            
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = T.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass

        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    transform = T.Compose([T.ToTensor()])
    dataset = CustomDatasetFromCSV('./mnist_labels.csv')
    print(len(dataset))
    print(dataset[5])

運行示例如下:

訓練集驗證集劃分

一般來說,爲了模型訓練的穩定,我們需要對數據劃分訓練集和驗證集。torch的Dataset對象也提供了random_split函數作爲數據劃分工具,且劃分結果可直接供後續的DataLoader使用。

以kaggle的花朵數據爲例:

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_split

transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
 ])

dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)

trainset, valset = random_split(dataset, 
                [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])

trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img, label)

valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img.shape, label)

這裏使用了ImageFolder模塊,可以直接讀取各標籤對應的文件夾,部分運行示例如下:

使用DataLoader

dataset方法寫好之後,我們還需要使用DataLoader將其逐個餵給模型。上一節的數據劃分我們已經用到了DataLoader函數。從本質上來講,DataLoader只是調用了__getitem__()方法並按批次返回數據和標籤。使用方法如下:

from torch.utils.data import DataLoader
from torchvision import transforms as T

if __name__ == "__main__":
    # Define transforms
    transformations = T.Compose([T.ToTensor()])
    # Define custom dataset
    dataset = CustomDatasetFromCSV('./labels.csv')
    # Define data loader
    data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
    for images, labels in data_loader:
        # Feed the data to the model

以上就是PyTorch讀取數據的Pipeline主要方法和流程。基於Dataset對象的基本框架不變,具體細節可自定義化調整。

本文原創首發於公衆號【機器學習實驗室】,開創了【深度學習60講】、【機器學習算法手推30講】和【深度學習100問】三大系列文章。

一個算法工程師的成長之路


長按二維碼.關注機器學習實驗室

機器學習實驗室的近期文章:

參考文獻

【1】https://pytorch.org/docs/stable/data.html

【2】https://towardsdatascience.com/building-efficient-custom-datasets-in-pytorch-2563b946fd9f

【3】https://github.com/utkuozbulak/pytorch-custom-dataset-examples

夕小瑤的賣萌屋

_

關注&星標小夕,帶你解鎖AI祕籍

訂閱號主頁下方「撩一下」有驚喜哦

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章