Pytorch 繼承Dataset加載自己的數據集

1、應用場景

在使用 Pytorch 做分類任務的時候,一般會用自帶的torchvision.datasets.ImageFolder()函數,但是這個對數據存儲方式有要求,不一定適合自己,如果考慮加載自己的數據,就要考慮重寫Dataset類了。

ImageFolder 對數據存儲方式要求:
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        ... ...
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

2、定製自己的數據加載方式

一般我們不想來回移動數據,知道圖片的路徑即可,告訴模型在哪裏自己去拿,是比較好的方式。所以我們只要繼承Dataset類,重新實現一下即可。

大致方法可分爲三步:
  1. 把圖片的路徑和label整理到文本中(什麼文本都可以,方式也不限,但要方便自己解析)。
  2. 將數據信息,解析,並存到list中。
  3. 重新實現,__getitem__() 函數,讀取每條數據和標籤,並返回。
train.txt --(第1列是數據路徑,第2列標籤)
        root/dog/xxx.png	0
        root/dog/xxy.png	0
        root/dog/xxz.png	0
        root/cat/123.png	1
        root/cat/nsdf3.png	1
        root/cat/asd932_.png	1
具體代碼實現
#!/usr/bin/python
# -*- coding: UTF-8 -*-

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch


__all__ = ['MyDataset']


class MyDataset(Dataset):

    def __init__(self, dataPath, transform=None, target_transform=None):
        imgsPath = open(dataPath, 'r')
        imgs = []
        for line in imgsPath:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


if __name__ == '__main__':

    transform_train = transforms.Compose([transforms.Resize(256),  # 重置圖像分辨率
                                          transforms.RandomResizedCrop(224),  # 隨機裁剪
                                          transforms.RandomHorizontalFlip(),  # 以概率p水平翻轉
                                          transforms.RandomVerticalFlip(),  # 以概率p垂直翻轉
                                          transforms.ToTensor(),])
    trainset = MyDataset(dataPath='train.txt', transform=transform_train)  # 訓練集
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    for step, (tx, ty) in enumerate(trainloader, 0):
        print('---test---', tx, ty)

聲明: 總結學習,有問題或不當之處,可以批評指正哦,謝謝。

優秀的參考鏈接

[1]:https://github.com/tensor-yu/PyTorch_Tutorial
[2]:https://blog.csdn.net/u011995719/article/details/85102770

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章