pytorch 读取数据方法总结

用pytorch读取数据,确实要比tensorflow简单,但是也得熟悉半个小时左右.

下面总结下我的体验,直接用代码
(1)torch.utils.data.Dataset
(2)torch.utils.data.DataLoader
这两个类搭配的数据读取代码:

import os
import glob
import cv2
import numpy as np
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader


#第一种数据读取方式
transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.RandomHorizontalFlip(),
    T.RandomSizedCrop(224),
    T.ToTensor(),#将图片从0-255变为0-1
    T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])#标准化到[-1,1]
])

class Test_Data(Dataset):
    def __init__(self,data_root,mask_root,transforms=None):

        data_image = glob.glob(data_root+'/*.jpg')
        self.data_image = data_image

        mask_image = glob.glob(mask_root+'/*.jpg')
        self.mask_image = mask_image

        self.transforms = transforms

    def __getitem__(self, index):

        data_image_path = self.data_image[index]
        mask_image_path = self.mask_image[index]

        image_data = cv2.imread(data_image_path,-1)
        mask_data = cv2.imread(mask_image_path,-1)
        if self.transforms:
            image_data = self.transforms(image_data)
            mask_data = self.transforms(mask_data)

        return image_data,mask_data

    def __len__(self):
        return len(self.data_image)

dataset = Test_Data(data_root='../../test_image/37_simple/0001/data',mask_root='../../test_image/37_simple/0001/mask')

#第一种调用,不常用
for data,mask in dataset:
     print(data.shape,mask.shape)

下面两种应该在训练过程中更加好:
test_data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, pin_memory=True, shuffle=True,drop_last=True)
#第一种
# for i,data in enumerate(test_data_loader,0):
#     print(data[0].shape,'..')
#     print(data[1].shape,'...')
#第2种
for data_batch,mask_batch in test_data_loader:
    print(data_batch.size(),mask_batch.size())

还有一种通过from torchvision.datasets import ImageFolder来访问文件数据
不过,我觉得这种更加适合在分类任务中应用

from torchvision.datasets import ImageFolder
#
dataset_data = ImageFolder('../../test_image/37_simple/0001/',transform=None)
#
print(dataset_data.class_to_idx)
print(dataset_data.img)

这种还没有用过

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章