Learning PyTorch: Data Loading and Processing

Introduction

This article is based on the official tutorials, their source code, and several blog posts.

Data loading and processing in PyTorch is relatively straightforward. There are two common ways of organizing data for import:

  1. The entire dataset sits in a single folder, with a separate label file describing each sample, such as this dataset. This way of storing data is probably better suited to non-classification problems.
  2. The other layout is better suited to classification problems: samples of different classes are stored in separate folders, like this:

    root/ants/xxx.png
    root/ants/xxy.jpeg
    root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png

This article first walks through the first approach following the official tutorials, to understand how data loading works under the hood; it then briefly demonstrates the second approach in code. The two approaches share the same underlying principle; the difference is that the second uses ImageFolder, a ready-made utility provided by torchvision, so it is simpler to implement.

Method 1

Dataset class

torch.utils.data.Dataset is an abstract class. To load a custom dataset, you only need to subclass it and override two methods:

  1. __len__: override this so that len(dataset) returns the size of the whole dataset
  2. __getitem__: override this so that dataset[i] returns the i-th sample of the dataset
  3. If you do not override them, an error is raised directly; the relevant source is shown below, followed by a minimal toy example:
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
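
As a quick illustration (this toy class is my own sketch, not part of the tutorial), wrapping a plain Python list is enough to see how the two overridden methods hook into len() and indexing:

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples      # any indexable container

    def __len__(self):
        return len(self.samples)    # len(dataset) dispatches here

    def __getitem__(self, idx):
        return self.samples[idx]    # dataset[idx] dispatches here

toy = ToyDataset([10, 20, 30])
print(len(toy))   # 3
print(toy[1])     # 20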

Here I downloaded 20 images from the web: 10 cats and 10 dogs. To keep things simple (I only wanted to verify that subclassing Dataset works), I did not add a label file to the dataset; instead I simply defined images 1-10 as cats and 11-20 as dogs, which makes __len__ and __getitem__ easier to write. The directory structure is as follows:
Directory structure

The custom Dataset class is as follows:

from torch.utils.data import DataLoader, Dataset
from skimage import io, transform
import matplotlib.pyplot as plt 
import os 
import torch
from torchvision import transforms
import numpy as np 

class AnimalData(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
    
    def __len__(self):
        # 20 images in total (10 cats, 10 dogs), hard-coded to keep the example simple
        return 20

    def __getitem__(self, idx):
        filenames = sorted(os.listdir(self.root_dir))
        filename = filenames[idx]
        img = io.imread(os.path.join(self.root_dir, filename))
        # files numbered 1-10 are cats (label 1), 11-20 are dogs (label 0)
        if int(os.path.splitext(filename)[0]) > 10:
            label = np.array([0])
        else:
            label = np.array([1])
        sample = {'image': img, 'label': label}

        if self.transform:
            sample = self.transform(sample)
        return sample
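
A quick sanity check of the class (a sketch of my own, assuming the 20 images live in ./all as shown in the directory structure above): indexing returns a dict holding the raw image array and its label.

data = AnimalData('./all')
sample = data[5]
print(sample['image'].shape)   # (H, W, 3) for an RGB image; size varies per file
print(sample['label'])         # [1] for images 1-10 (cats), [0] for images 11-20 (dogs)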

Transforms & Compose transforms

Notice that __init__ of the AnimalData class in the previous section takes a transform argument; that is what this section is about.
Images downloaded from the web inevitably come in different sizes, while a CNN expects inputs of a fixed size; a numpy image is laid out as H x W x C, whereas a PyTorch image is C x H x W; the numpy arrays have to be converted to tensors before being fed to PyTorch; inputs usually also need to be normalized; and so on.
To handle all of this, we can define some callable classes and pass them as the transform argument of the dataset class defined above. To make this even more convenient, torchvision.transforms provides the Compose class, which chains our custom callable classes together in one go, so the dataset directly returns the transformed samples.
Here I simply copied the three classes from the tutorial, Rescale, RandomCrop and ToTensor, with slight modifications to fit my dataset.

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            # match the shorter edge to output_size, keeping the aspect ratio
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # unlike the landmarks in the tutorial, the class label is not
        # affected by resizing, so it is passed through unchanged
        return {'image': img, 'label': label}

class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # +1 so that an image exactly output_size high/wide is still valid
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # the class label is unaffected by cropping
        return {'image': image, 'label': label}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(label)}

Once the callable classes are defined, combine the three of them with torchvision.transforms.Compose and pass the result as the transform argument of AnimalData.

trsm = transforms.Compose([Rescale(256),
                            RandomCrop(224),
                            ToTensor()])
data = AnimalData('./all', transform=trsm)
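
Fetching one sample is a quick way to confirm the pipeline works (a simple check of my own, not from the tutorial); after Rescale, RandomCrop and ToTensor, the image should be a 3 x 224 x 224 tensor.

sample = data[0]
print(sample['image'].size())  # expected: torch.Size([3, 224, 224])
print(sample['label'])         # tensor([1]) for a cat, tensor([0]) for a dog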

Iterating through the dataset

With the data instance from the previous section we could read samples one by one in a for loop, but that is inefficient. The torch.utils.data.DataLoader class solves this problem. Its main features are:

  • Batching the data
  • Shuffling the data
  • Load the data in parallel using multiprocessing workers.

Using it is also simple:

dataloader = DataLoader(data, batch_size=4, shuffle=True, num_workers=4)
for i_batch, batch_data in enumerate(dataloader):
    print(i_batch)
    print(batch_data['image'].size())
    print(batch_data['label'])
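
With batch_size=4, the default collate function stacks the fields of each sample, so batch_data['image'].size() should come out as torch.Size([4, 3, 224, 224]) and batch_data['label'] as a tensor of shape [4, 1].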

Method 2

torchvision

PyTorch packages almost all of the work above into ready-made utilities. One of them is torchvision.datasets.ImageFolder, which loads user-defined data and expects it to have the following structure:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

torchvision.transforms also packages all kinds of data processing tools for us, such as Resize, ToTensor, and so on.
I reorganized my downloaded images into the following structure:

method2_tree

The code to load the data is as follows:

from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt 

train_data = datasets.ImageFolder('./data1', transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
]))

train_loader = torch.utils.data.DataLoader(train_data,
                                            batch_size=4,
                                            shuffle=True,
                                            )
                                            
print(len(train_loader))
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
    break

Result:
make_grid
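
ImageFolder infers the class names from the subfolder names and maps them to integer labels; this is easy to inspect (a quick check, assuming ./data1 contains cat and dog subfolders as in my layout above):

print(train_data.classes)       # e.g. ['cat', 'dog']
print(train_data.class_to_idx)  # e.g. {'cat': 0, 'dog': 1}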

Appendix

Finally, let's take a look at a piece of the torchvision source code:

# vision/torchvision/datasets/folder.py

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Checks if a file is an image.
    Args:
        filename (string): path to a file
    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
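
In short, find_classes maps the subfolder names to consecutive integer indices, make_dataset walks each subfolder and collects (image path, class_index) tuples, and __getitem__ loads an image with the configured loader and applies the optional transforms; this is the same pattern we wrote by hand in the first method.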

References

[1]. Data Loading and Processing Tutorial
[2]. GitHub: pytorch/torch/utils/data/dataset.py
[3]. GitHub: vision/torchvision/datasets/folder.py
[4]. CSDN
