PyTorch (Notes 9): Loading Custom Data

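PyTorch (through torchvision) provides a data interface called datasets, which wraps many public datasets such as CIFAR10/100 and ImageNet. They can be loaded with a simple call like the one below. But how do we use PyTorch to load a training set we built ourselves? Let's look for the answer in the source code!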

train_data = datasets.CIFAR10('./cifa10', train=True, transform=train_transform, download=True)

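The source shows that the CIFAR10 class inherits from VisionDataset, which is a subclass of Dataset, and that it implements three methods: __init__, __len__, and __getitem__. In fact, writing a custom data interface and training on it with PyTorch is simple: subclass the base class Dataset and implement these three methods.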

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
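
As a minimal sketch of the contract described above, the class below subclasses Dataset and implements the three methods over made-up in-memory data. The name SquaresDataset and its contents are hypothetical and chosen only for illustration; they do not come from the torchvision source.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Toy dataset (hypothetical): sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Size of the dataset; samplers use it to enumerate valid indices.
        return self.n

    def __getitem__(self, index):
        # Integer indexing in the range 0 .. len(self) - 1.
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index


ds = SquaresDataset(5)
both = ds + SquaresDataset(3)  # Dataset.__add__ builds a ConcatDataset
```

Here ds[3] is the pair (3, 9), and the concatenated dataset simply continues indexing into the second dataset once the first one is exhausted.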

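For loading your own dataset, PyTorch also provides a ready-made interface, torchvision.datasets.ImageFolder. It is somewhat restrictive, though: the data must follow its directory structure, /root/ids/*.jpg, that is, one sub-directory per class ID with the images inside it.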

The __init__ method

    def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root)
        self.transform = transform
        self.target_transform = target_transform
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                                "Supported extensions are: " + ",".join(extensions)))

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

Let's step through it in a debugger and see what this method actually does.

First, self.root is the training-set directory we passed in. The first thing the constructor does is call _find_classes, so let's look at the source of _find_classes:

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

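The return value classes is a sorted list of the IDs (that is, the labels), and class_to_idx is the matching dictionary whose keys are the IDs and whose values are their positions in that list, as shown below: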

['102091655-1-201811011700-16', '10209231-1-201811010900-2', '1020962212-2-201811010900-24', '1020966131-3-201811011700-0', '102097752-0-201811010900-6']

{'1020962212-2-201811010900-24': 2, '1020966131-3-201811011700-0': 3, '102097752-0-201811010900-6': 4, '10209231-1-201811010900-2': 1, '102091655-1-201811011700-16': 0}
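
The mapping step can be tried out without torchvision. The following is a rough stand-alone sketch of the same idea, assuming a throwaway directory with made-up class names (dog, cat, bird); the function name find_classes is my own and is not the library method.

```python
import os
import tempfile


def find_classes(root):
    """Return (classes, class_to_idx) for the sub-directories of root."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


# Build a throwaway layout: three class folders plus one stray file.
with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat", "bird"):
        os.mkdir(os.path.join(root, name))
    open(os.path.join(root, "README.txt"), "w").close()
    classes, class_to_idx = find_classes(root)

print(classes)
print(class_to_idx)
```

Note that folder names are sorted as strings, not as numbers, so a folder named 10 would come before one named 9. That is why '102091655-...' gets index 0 and '10209231-...' gets index 1 in the listing above.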

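Next, samples receives the return value of make_dataset. Here extensions lists the image file extensions that are accepted, and is_valid_file is a function used to check whether a file is valid. As the code below shows, exactly one of the two must be given.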

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    if not ((extensions is None) ^ (is_valid_file is None)):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = (path, class_to_idx[target])
                    images.append(item)
    return images

A sample of samples is shown below: it is a list of tuples, each holding an image path and the corresponding label.

[('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000702_crop16.jpg', 0),

('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000880_crop16.jpg', 0),

('test/10209231-1-201811010900-2/10.209.23.1-1-201811010900-201811010903_00000092_crop2.jpg', 1),

('test/1020962212-2-201811010900-24/10.209.62.212-2-201811010900-201811010903_00000756_crop24.jpg', 2),

('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000295_crop0.jpg', 3),

('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000302_crop0.jpg', 3),

('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000395_crop6.jpg', 4),

('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000434_crop6.jpg', 4)]
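
To make the filtering step concrete, here is a simplified stand-alone sketch of the same walk, assuming a shortened extension list and a temporary directory with made-up files. The function list_samples is hypothetical; it only mirrors the logic of make_dataset for the extensions case and is not the torchvision implementation.

```python
import os
import tempfile

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")  # shortened list for the demo


def list_samples(root, class_to_idx, extensions=IMG_EXTENSIONS):
    """Collect (path, class_index) pairs, keeping only allowed extensions."""
    samples = []
    for target in sorted(class_to_idx):
        class_dir = os.path.join(root, target)
        if not os.path.isdir(class_dir):
            continue
        for dirpath, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                # Case-insensitive suffix check on the file name.
                if fname.lower().endswith(extensions):
                    path = os.path.join(dirpath, fname)
                    samples.append((path, class_to_idx[target]))
    return samples


with tempfile.TemporaryDirectory() as root:
    for cls, fname in [("a", "1.jpg"), ("a", "notes.txt"), ("b", "2.PNG")]:
        os.makedirs(os.path.join(root, cls), exist_ok=True)
        open(os.path.join(root, cls, fname), "w").close()
    samples = list_samples(root, {"a": 0, "b": 1})
    relative = [(os.path.relpath(p, root), t) for p, t in samples]
```

The text file is skipped, while 2.PNG is kept because the check lower-cases the name first.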

Next comes the assignment of loader, which is a function passed in as an argument. We usually load images with the pil_loader function.

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

 

__getitem__ and __len__

__getitem__ is what DataLoader builds on: it takes an index and returns the transformed image together with its label. __len__ returns the length of the dataset.

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)
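
The order of operations in __getitem__ (load, then transform the sample, then transform the target) can be checked without any image files. The sketch below is a plain-Python stand-in, not the torchvision class: FolderLikeDataset, fake_loader, and the two lambdas are all hypothetical names invented for this demonstration.

```python
class FolderLikeDataset:
    """Stand-in for DatasetFolder: same __getitem__/__len__ flow, no torch."""

    def __init__(self, samples, loader, transform=None, target_transform=None):
        self.samples = samples
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)  # 1. load the sample from its path
        if self.transform is not None:
            sample = self.transform(sample)  # 2. transform the sample
        if self.target_transform is not None:
            target = self.target_transform(target)  # 3. transform the label
        return sample, target

    def __len__(self):
        return len(self.samples)


def fake_loader(path):
    # Pretend "loading" just returns the path in upper case.
    return path.upper()


ds = FolderLikeDataset(
    samples=[("a/1.jpg", 0), ("b/2.jpg", 1)],
    loader=fake_loader,
    transform=lambda s: s[::-1],         # reverse the fake "image"
    target_transform=lambda t: t + 100,  # shift the label
)
first = ds[0]
```

In real use the loader would return a PIL image and the transform would typically end with a conversion to a tensor, but the control flow is the same.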

 

DIY Interface (Custom Interface)

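If you can follow how these functions work, you are ready to define the data interface you need. Suppose our train.txt, val.txt, and test.txt use the format below, with an image path and a label separated by a tab on each line. How would we write our own versions of the three methods discussed above?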

/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002532_crop23.jpg	1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002521_crop23.jpg	1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002535_crop23.jpg	2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002528_crop23.jpg	2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002523_crop23.jpg	2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002529_crop23.jpg	3
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002527_crop23.jpg	3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000833_crop69.jpg	3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000834_crop69.jpg	4
/20190424/00001320000179-1556104800-30/SZ009SZZP3-32130200001320000179-1556104800_00001954_crop30.jpg	4
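
Before writing the class, it may help to see how one such line is turned into a usable (path, label) pair. The sketch below assumes shortened, made-up file names and uses /home/badoo/person as the image root; parse_list is a hypothetical helper. One detail worth noticing is that the listed paths begin with '/', and joining an absolute path onto a root discards the root.

```python
import posixpath

SAMPLE_LINES = [
    "/20190424/a/img_0001.jpg\t1\n",
    "/20190424/a/img_0002.jpg\t2\n",
    "\n",
]


def parse_list(lines, root):
    """Turn '<path>\\t<label>' lines into (full_path, int_label) pairs."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        rel_path, label = line.split("\t")
        # join() drops root when the second part is absolute, so strip '/'.
        full_path = posixpath.join(root, rel_path.lstrip("/"))
        samples.append((full_path, int(label)))
    return samples


samples = parse_list(SAMPLE_LINES, "/home/badoo/person")
# What happens without stripping the leading slash:
naive = posixpath.join("/home/badoo/person", "/20190424/a/img_0001.jpg")
```

The labels are converted to int here so that they can later be batched into a tensor of class indices.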

Below is a rough sketch that I have not debugged; it is mainly meant to illustrate the idea!

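# -*- coding: utf-8 -*-
import os

import torch.utils.data as data
from torchvision.datasets.folder import default_loader


class customData(data.Dataset):
    """Dataset backed by a text file of '<image path><TAB><label>' lines."""

    def __init__(self, img_path, txt_path, dataset=None, data_transforms=None, loader=default_loader):
        self.images = []
        self.labels = []
        with open(txt_path) as data_input:
            for line in data_input:
                line = line.strip()
                if not line:
                    continue  # skip blank lines
                rel_path, label = line.split('\t')
                # The listed paths start with '/', which would make
                # os.path.join discard img_path, so strip it first.
                self.images.append(os.path.join(img_path, rel_path.lstrip('/')))
                # Labels are class indices, not paths.
                self.labels.append(int(label))
        self.transform = data_transforms  # callable applied to each loaded image
        self.dataset = dataset            # split name, e.g. 'train' or 'val'
        self.loader = loader

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Load the image first, then transform the loaded image (not the path).
        img_data = self.loader(self.images[index])
        label = self.labels[index]
        if self.transform is not None:
            img_data = self.transform(img_data)
        return img_data, label
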

It can be called like this, which completes the loading of the custom data.

image_datasets = {x: customData(img_path='/home/badoo/person',
                                txt_path=('/home/badoo/train_list/' + x + '.txt'),
                                data_transforms=data_transforms,
                                dataset=x) for x in ['train', 'val']}

 DataLoader

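As mentioned earlier, the input during training is usually a tensor of shape [N, C, H, W]. PyTorch provides the DataLoader API for loading data in batches. It pulls samples from the dataset, where the transform (for example ToTensor()) has already been applied in __getitem__, and collates them into batches. Operations such as BatchNorm belong to the model, not to the DataLoader. The source code is worth reading for reference.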

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                             batch_size=batch_size,
                                             shuffle=True) for x in ['train', 'val']}

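Parameters

1. dataset: the output of an existing PyTorch data interface (such as torchvision.datasets.ImageFolder) or of a custom one. It is either an object of class torch.utils.data.Dataset or an object of a custom class that inherits from torch.utils.data.Dataset.

2. batch_size: set it according to your situation.

3. shuffle: usually enabled for training data.

4. collate_fn: controls how the individual samples from the dataset are merged into a batch. The default is usually fine unless your custom dataset returns something very unusual.

5. batch_sampler: as the docstring notes, it is mutually exclusive with parameters such as batch_size and shuffle. The default is usually fine.

6. sampler: as the code shows, it is mutually exclusive with shuffle. The default is usually fine.

7. num_workers: as the docstring notes, this must be greater than or equal to 0. A value of 0 means the data is loaded in the main process; any larger value loads data with that many worker processes, which can speed up loading.

8. pin_memory: the docstring is quite clear: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. In other words, it is about how the data is copied.

9. timeout: sets the timeout for reading data; if no data has been read by the time it expires, an error is raised.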
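
A few of these parameters can be tried on fake data. The sketch below assumes ten all-zero "images" wrapped in a TensorDataset, which stands in for a real image dataset; the shapes and sizes are arbitrary choices for the demonstration.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# Ten fake "images" of shape [C, H, W] = [3, 4, 5] with integer labels.
images = torch.zeros(10, 3, 4, 5)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# num_workers=0 loads in the main process; batch_size=4 yields 4 + 4 + 2.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)
batch_shapes = [tuple(x.shape) for x, _ in loader]

# sampler and shuffle=True cannot be combined.
try:
    DataLoader(dataset, shuffle=True, sampler=SequentialSampler(dataset))
    conflict = None
except ValueError as err:
    conflict = str(err)
```

The last batch is smaller because 10 is not a multiple of 4, and the second DataLoader call is rejected with a ValueError, which is the mutual exclusion mentioned in item 6.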

Below are two ways of writing the call; I prefer the second one ^_^

# Style 1:
train_data = torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    ...

# Style 2:
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
for i, (ids, labels) in enumerate(train_load):
    ...

Sticking with something may be hard, but seeing it through is definitely cool! ^_^
