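PyTorch provides a data interface, torchvision.datasets, which wraps many public datasets such as CIFAR10/100 and ImageNet. They can be loaded with a simple call like the one below. But how do we use PyTorch to load a training set we have built ourselves? Let's find the answer in the source code!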
from torchvision import datasets
train_data = datasets.CIFAR10('./cifar10', train=True, transform=train_transform, download=True)  # train_transform: a transform pipeline defined elsewhere
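The source code shows that class CIFAR10 inherits from VisionDataset, a subclass of Dataset, and implements three methods: __init__, __len__ and __getitem__. Defining our own data interface and training on it with PyTorch is just as simple: subclass the base class Dataset (shown below) and implement these three methods.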
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
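A minimal sketch of such a subclass follows. SquaresDataset is a hypothetical toy dataset, not part of PyTorch: it only implements __init__, __len__ and __getitem__, and the last lines show that the inherited __add__ lets two datasets be joined with + into a ConcatDataset.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is (x, x*x) for x in [start, stop)."""

    def __init__(self, start, stop):
        self.values = list(range(start, stop))

    def __len__(self):
        # number of samples
        return len(self.values)

    def __getitem__(self, index):
        # integer index in [0, len(self)) -> (input, target)
        x = self.values[index]
        return torch.tensor(float(x)), torch.tensor(float(x * x))


a = SquaresDataset(0, 3)
b = SquaresDataset(10, 12)
combined = a + b  # Dataset.__add__ returns ConcatDataset([a, b])
print(type(combined).__name__, len(combined))  # ConcatDataset 5
x, y = combined[3]  # index 3 is the first item of b
print(x.item(), y.item())  # 10.0 100.0
```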
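For loading your own images, torchvision also provides a ready-made interface, torchvision.datasets.ImageFolder, but it is more restrictive: the data must follow its directory layout, /root/ids/*.jpg, i.e. one subdirectory per class ID under the root directory.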
The __init__ method (ImageFolder subclasses DatasetFolder, whose __init__ is shown here)
def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
    super(DatasetFolder, self).__init__(root)
    self.transform = transform
    self.target_transform = target_transform
    classes, class_to_idx = self._find_classes(self.root)
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
    if len(samples) == 0:
        raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                            "Supported extensions are: " + ",".join(extensions)))
    self.loader = loader
    self.extensions = extensions
    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples
    self.targets = [s[1] for s in samples]
Let's step through it in a debugger to see what this method does.
First, self.root is the custom directory we passed in, i.e. our training-set directory. The first step is _find_classes, whose source is:
def _find_classes(self, dir):
    """
    Finds the class folders in a dataset.

    Args:
        dir (string): Root directory path.

    Returns:
        tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

    Ensures:
        No class is a subdirectory of another.
    """
    if sys.version_info >= (3, 5):
        # Faster and available in Python 3.5 and above
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    else:
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx
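The returned classes is a sorted list of the IDs (the subdirectory names), which serve as the labels, and class_to_idx is a dictionary mapping each ID (key) to its position in that list (value), for example: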
['102091655-1-201811011700-16', '10209231-1-201811010900-2', '1020962212-2-201811010900-24', '1020966131-3-201811011700-0', '102097752-0-201811010900-6']
{'1020962212-2-201811010900-24': 2, '1020966131-3-201811011700-0': 3, '102097752-0-201811010900-6': 4, '10209231-1-201811010900-2': 1, '102091655-1-201811011700-16': 0}
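The index assignment can be reproduced in plain Python with the IDs from the output above (listed here in an arbitrary order). Because the sort is a plain string sort, 102091655 gets index 0 even though it is numerically larger than 10209231; the enumerate form is just an equivalent way of writing the dict comprehension in _find_classes.

```python
ids = [
    "1020966131-3-201811011700-0",
    "102097752-0-201811010900-6",
    "10209231-1-201811010900-2",
    "102091655-1-201811011700-16",
    "1020962212-2-201811010900-24",
]

classes = sorted(ids)  # lexicographic string sort, as classes.sort() does in _find_classes
class_to_idx = {name: i for i, name in enumerate(classes)}

print(classes[0])                                 # 102091655-1-201811011700-16
print(class_to_idx["10209231-1-201811010900-2"])  # 1
```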
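Next, samples receives the return value of make_dataset. extensions is the tuple of accepted image file extensions (ImageFolder passes IMG_EXTENSIONS unless is_valid_file is given), and is_valid_file is an optional callable that decides whether a file is a valid sample; as the check at the top of make_dataset enforces, exactly one of the two must be provided.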
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    if not ((extensions is None) ^ (is_valid_file is None)):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = (path, class_to_idx[target])
                    images.append(item)
    return images
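To see the effect of the extension filter and of os.walk, here is a simplified re-implementation (list_samples is a hypothetical name, not torchvision code) run on a throwaway tree of empty files. Non-image files are skipped, the extension check is case-insensitive, and images in nested subfolders still receive their top-level class index.

```python
import os
import tempfile

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def list_samples(root, class_to_idx, extensions=IMG_EXTENSIONS):
    """Simplified sketch of make_dataset's logic, for illustration only."""
    samples = []
    for target in sorted(class_to_idx):
        class_dir = os.path.join(root, target)
        if not os.path.isdir(class_dir):
            continue
        for dirpath, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                # case-insensitive suffix check, like has_file_allowed_extension
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[target]))
    return samples


root = tempfile.mkdtemp()
for name in ["a/1.jpg", "a/notes.txt", "b/2.PNG", "b/sub/3.jpeg"]:
    path = os.path.join(root, *name.split("/"))
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "w").close()  # empty placeholder files; only the names matter here

samples = list_samples(root, {"a": 0, "b": 1})
print([(os.path.relpath(p, root), y) for p, y in samples])
# on POSIX: [('a/1.jpg', 0), ('b/2.PNG', 1), ('b/sub/3.jpeg', 1)]
```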
Here is a sample of samples: a list of tuples, each holding an image path and its label (the class index).
[('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000702_crop16.jpg', 0),
('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000880_crop16.jpg', 0),
('test/10209231-1-201811010900-2/10.209.23.1-1-201811010900-201811010903_00000092_crop2.jpg', 1),
('test/1020962212-2-201811010900-24/10.209.62.212-2-201811010900-201811010903_00000756_crop24.jpg', 2),
('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000295_crop0.jpg', 3),
('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000302_crop0.jpg', 3),
('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000395_crop6.jpg', 4),
('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000434_crop6.jpg', 4)]
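Next comes the assignment of loader, a function passed in as an argument that turns a file path into an image. ImageFolder uses default_loader, which calls pil_loader unless the accimage backend has been selected: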
def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
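Since loader is just a callable from path to image, it can be swapped. The sketch below defines a hypothetical gray_loader following the same pattern as pil_loader and compares the two on a generated grayscale PNG: pil_loader always expands to 3-channel RGB, while the custom loader keeps a single channel.

```python
import os
import tempfile

from PIL import Image


def pil_loader(path):
    # same logic as torchvision's pil_loader shown above
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def gray_loader(path):
    """Hypothetical loader: like pil_loader but returns a single-channel 'L' image."""
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('L')  # convert() reads the pixels before the file is closed


path = os.path.join(tempfile.mkdtemp(), "sample.png")
Image.new("L", (4, 4), color=128).save(path)  # grayscale test image

print(pil_loader(path).mode)   # RGB
print(gray_loader(path).mode)  # L
```

Such a loader could then be passed as datasets.ImageFolder(root, loader=gray_loader).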
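__getitem__ and __len__
__getitem__ is what the DataLoader builds on: it takes an integer index and returns the transformed image and its label. __len__ returns the length of the dataset, i.e. the number of samples.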
def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)
    return sample, target

def __len__(self):
    return len(self.samples)
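The control flow of these two methods can be traced without any image files. TinyFolderDataset below is a hypothetical stand-in that copies the __getitem__/__len__ logic; the fake loader returns a string, str.upper stands in for an image transform, and the target_transform one-hot encodes the class index for two classes.

```python
class TinyFolderDataset:
    """Hypothetical stand-in mirroring DatasetFolder's __getitem__/__len__."""

    def __init__(self, samples, loader, transform=None, target_transform=None):
        self.samples = samples  # list of (path, class_index)
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)


ds = TinyFolderDataset(
    samples=[("cat/1.jpg", 0), ("dog/7.jpg", 1)],
    loader=lambda p: f"<image {p}>",   # fake loader: a string instead of a PIL image
    transform=str.upper,               # stand-in for e.g. transforms.ToTensor()
    target_transform=lambda y: [int(i == y) for i in range(2)],  # one-hot, 2 classes
)
print(len(ds))  # 2
print(ds[1])    # ('<IMAGE DOG/7.JPG>', [0, 1])
```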
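DIY Interface (custom data interface)
Once you understand how these functions work, you can start defining the data interface you need. Suppose train.txt, val.txt and test.txt each list one image path (relative to the image root) and one integer label per line, as below. How should we implement the three methods above for this format?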
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002532_crop23.jpg 1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002521_crop23.jpg 1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002535_crop23.jpg 2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002528_crop23.jpg 2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002523_crop23.jpg 2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002529_crop23.jpg 3
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002527_crop23.jpg 3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000833_crop69.jpg 3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000834_crop69.jpg 4
/20190424/00001320000179-1556104800-30/SZ009SZZP3-32130200001320000179-1556104800_00001954_crop30.jpg 4
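One detail of this format deserves care: every path starts with "/", and os.path.join discards all earlier components when a later one is absolute, so the image root would silently be lost. The sketch below parses two of the lines above, using the root directory from the call later in this post as an example; str.split() with no argument handles spaces or tabs as well as the trailing newline.

```python
import os

lines = [
    "/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002532_crop23.jpg 1\n",
    "/20190424/00001320000179-1556104800-30/SZ009SZZP3-32130200001320000179-1556104800_00001954_crop30.jpg 4\n",
]
root = "/home/badoo/person"  # example image root

# Pitfall: the absolute second argument wins, so root disappears (on POSIX).
print(os.path.join(root, lines[0].split()[0]))

samples = []
for line in lines:
    rel_path, label = line.split()  # whitespace split also drops the newline
    # lstrip('/') makes the path relative so it is appended to root
    samples.append((os.path.join(root, rel_path.lstrip("/")), int(label)))
print(samples[1])
```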
Below is my sketch; it was not debugged and is mainly meant to illustrate the idea!
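# -*- coding: utf-8 -*-
import os

import torch.utils.data as data
from torchvision.datasets.folder import default_loader


class customData(data.Dataset):
    def __init__(self, img_path, txt_path, dataset=None, data_transforms=None, loader=default_loader):
        # each non-empty line: "<image path relative to img_path> <integer label>"
        with open(txt_path) as data_input:
            lines = [line.split() for line in data_input if line.strip()]
        # lstrip('/') keeps os.path.join from discarding img_path for paths starting with '/'
        self.images = [os.path.join(img_path, line[0].lstrip('/')) for line in lines]
        # labels are integer class indices, not paths
        self.labels = [int(line[1]) for line in lines]
        # assumed: data_transforms is a dict of transforms keyed by split name ('train', 'val', ...)
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img = self.loader(self.images[index])  # PIL RGB image with the default loader
        label = self.labels[index]
        if self.data_transforms is not None:
            # pick the pipeline for this split, e.g. data_transforms['train']
            img = self.data_transforms[self.dataset](img)
        return img, label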
It can then be instantiated like this, which completes the custom data-loading process.
image_datasets = {x: customData(img_path='/home/badoo/person',
                                txt_path=('/home/badoo/train_list/' + x + '.txt'),
                                data_transforms=data_transforms,
                                dataset=x) for x in ['train', 'val']}
DataLoader
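As mentioned earlier, the training input is usually a tensor of shape [N, C, H, W]. PyTorch provides the DataLoader API for loading data in batches: it calls the dataset's __getitem__ (where transforms such as ToTensor() are applied) and collates the individual samples into batches. Normalization layers such as BatchNorm belong to the model, not to the DataLoader. The source code is worth reading: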
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=batch_size,
                                              shuffle=True) for x in ['train', 'val']}
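A small sketch of what the DataLoader yields, using random tensors wrapped in TensorDataset (a standard torch.utils.data class) instead of real images: the default collate function stacks samples along a new first dimension, giving [N, C, H, W] batches, and the last batch is smaller when the dataset size is not a multiple of batch_size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake RGB "images" of size 32x32 with integer labels (random data, shapes only)
images = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 5, (10,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch_images, batch_labels in loader:
    # default collate stacks samples: [N, C, H, W] and [N]
    print(tuple(batch_images.shape), tuple(batch_labels.shape))
# batch sizes are 4, 4 and 2; drop_last=True would drop the incomplete last batch
```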
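1. dataset: the output of one of PyTorch's existing data interfaces (e.g. torchvision.datasets.ImageFolder) or of your own custom data interface; it must be an instance of torch.utils.data.Dataset or of a custom subclass of it.
2. batch_size: the number of samples per batch; set it to suit your task.
3. shuffle: whether to reshuffle the data every epoch; usually enabled for training data.
4. collate_fn: merges the list of samples fetched for one batch into batch tensors. The default is usually fine unless your dataset returns something unusual (e.g. variable-length sequences).
5. batch_sampler: as its docstring notes, it is mutually exclusive with batch_size, shuffle, sampler and drop_last; usually left at the default.
6. sampler: defines the order in which samples are drawn; as the code shows, it is mutually exclusive with shuffle; usually left at the default.
7. num_workers: as the docstring notes, it must be >= 0. 0 means data is loaded in the main process; a value greater than 0 uses that many worker subprocesses, which can speed up data loading.
8. pin_memory: the docstring says it clearly: "pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them." In other words, it concerns how data is copied: pinned memory allows faster host-to-GPU transfers.
9. timeout: the timeout for collecting a batch from the workers; if no data arrives within this time, an error is raised (0, the default, means no timeout).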
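As an example of when collate_fn matters: the default collate_fn cannot stack 1-D tensors of different lengths. The sketch below uses a hypothetical pad_collate that pads each batch with torch.nn.utils.rnn.pad_sequence and also returns the original lengths; a plain Python list serves as the dataset because DataLoader only needs indexing and len().

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length samples; default collation would fail on these.
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4]), torch.tensor([5, 6])]


def pad_collate(batch):
    """Pad a list of 1-D tensors to the same length; also return the true lengths."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


loader = DataLoader(sequences, batch_size=3, shuffle=False, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
print(padded.tolist())   # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
print(lengths.tolist())  # [3, 1, 2]
```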
Below are two ways of calling it; I prefer the second ^_^
# Style 1:
train_data = torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    ...

# Style 2:
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
for i, (ids, labels) in enumerate(train_load):
    ...
Sticking with something may be hard, but if you keep at it, it is bound to be cool! ^_^