pytorch 讀取數據方法總結

用pytorch讀取數據,確實要比tensorflow簡單,但是也得熟悉半個小時左右.

下面總結下我的體驗,直接用代碼
(1)torch.utils.data.Dataset
(2)torch.utils.data.DataLoader
這兩個類搭配的數據讀取代碼:

import os
import glob
import cv2
import numpy as np
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader


#第一種數據讀取方式
transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.RandomHorizontalFlip(),
    T.RandomSizedCrop(224),
    T.ToTensor(),#將圖片從0-255變爲0-1
    T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])#標準化到[-1,1]
])

class Test_Data(Dataset):
    def __init__(self,data_root,mask_root,transforms=None):

        data_image = glob.glob(data_root+'/*.jpg')
        self.data_image = data_image

        mask_image = glob.glob(mask_root+'/*.jpg')
        self.mask_image = mask_image

        self.transforms = transforms

    def __getitem__(self, index):

        data_image_path = self.data_image[index]
        mask_image_path = self.mask_image[index]

        image_data = cv2.imread(data_image_path,-1)
        mask_data = cv2.imread(mask_image_path,-1)
        if self.transforms:
            image_data = self.transforms(image_data)
            mask_data = self.transforms(mask_data)

        return image_data,mask_data

    def __len__(self):
        return len(self.data_image)

dataset = Test_Data(data_root='../../test_image/37_simple/0001/data',mask_root='../../test_image/37_simple/0001/mask')

#第一種調用,不常用
for data,mask in dataset:
     print(data.shape,mask.shape)

下面兩種應該在訓練過程中更加好:
test_data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, pin_memory=True, shuffle=True,drop_last=True)
#第一種
# for i,data in enumerate(test_data_loader,0):
#     print(data[0].shape,'..')
#     print(data[1].shape,'...')
#第2種
for data_batch,mask_batch in test_data_loader:
    print(data_batch.size(),mask_batch.size())

還有一種通過from torchvision.datasets import ImageFolder來訪問文件數據
不過,我覺得這種更加適合在分類任務中應用

from torchvision.datasets import ImageFolder
#
dataset_data = ImageFolder('../../test_image/37_simple/0001/',transform=None)
#
print(dataset_data.class_to_idx)
print(dataset_data.img)

這種還沒有用過

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章