用 PyTorch 讀取數據確實比 TensorFlow 簡單,但也得花半個小時左右來熟悉。
下面總結一下我的體驗,直接上代碼。
(1)torch.utils.data.Dataset
(2)torch.utils.data.DataLoader
這兩個類搭配的數據讀取代碼:
import os
import glob
import cv2
import numpy as np
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
# First way of reading data: a torchvision preprocessing pipeline.
# NOTE(review): the geometric transforms below are normally fed PIL images
# (or tensors), while cv2.imread yields NumPy arrays -- a conversion step such
# as T.ToPILImage() is presumably needed before using this with cv2 input;
# confirm against the torchvision version in use.
transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.RandomHorizontalFlip(),
    # RandomSizedCrop was a deprecated alias that newer torchvision releases
    # no longer provide; RandomResizedCrop is the current name.
    T.RandomResizedCrop(224),
    T.ToTensor(),  # rescales image values from 0-255 to 0-1
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # maps [0, 1] to [-1, 1]
])
class Test_Data(Dataset):
    """Paired image/mask dataset backed by two directories of ``.jpg`` files.

    Sample ``i`` is the ``i``-th image path paired with the ``i``-th mask path
    after both directory listings are sorted, so an image and its mask must
    sort to the same position (for example by sharing a file name).

    Args:
        data_root: Directory searched (non-recursively) for ``*.jpg`` images.
        mask_root: Directory searched (non-recursively) for ``*.jpg`` masks.
        transforms: Optional callable applied to the image and, in a separate
            call, to the mask.

    Raises:
        ValueError: If the two directories hold a different number of
            ``.jpg`` files, since index-based pairing would then be wrong.
    """

    def __init__(self, data_root, mask_root, transforms=None):
        # glob returns paths in arbitrary (filesystem) order; sorting both
        # listings keeps image i and mask i aligned.
        self.data_image = sorted(glob.glob(data_root + '/*.jpg'))
        self.mask_image = sorted(glob.glob(mask_root + '/*.jpg'))
        if len(self.data_image) != len(self.mask_image):
            raise ValueError(
                'image/mask count mismatch: %d images in %r vs %d masks in %r'
                % (len(self.data_image), data_root,
                   len(self.mask_image), mask_root)
            )
        self.transforms = transforms

    @staticmethod
    def _read_image(path):
        """Load ``path`` with OpenCV, raising instead of returning ``None``."""
        # Flag -1 is cv2.IMREAD_UNCHANGED: keep the file's own channels/depth.
        pixels = cv2.imread(path, -1)
        if pixels is None:
            # cv2.imread signals a missing or undecodable file with None.
            raise OSError('cv2.imread could not read %r' % (path,))
        return pixels

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair stored at ``index``.

        Raises:
            IndexError: If ``index`` is out of range (this also terminates
                plain ``for ... in dataset`` iteration).
            OSError: If either file cannot be read by OpenCV.
        """
        data_image_path = self.data_image[index]
        mask_image_path = self.mask_image[index]
        image_data = self._read_image(data_image_path)
        mask_data = self._read_image(mask_image_path)
        if self.transforms:
            # NOTE(review): the image and mask go through two independent
            # calls, so a random transform may presumably pick different
            # parameters for each -- confirm before using random augmentation.
            image_data = self.transforms(image_data)
            mask_data = self.transforms(mask_data)
        return image_data, mask_data

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.data_image)
# Build the paired dataset; no transforms are passed, so samples are returned
# exactly as cv2.imread produced them.
dataset = Test_Data(
    data_root='../../test_image/37_simple/0001/data',
    mask_root='../../test_image/37_simple/0001/mask',
)
# First way to consume it (rarely used): iterate the dataset object directly.
for sample in dataset:
    data, mask = sample
    print(data.shape, mask.shape)
下面兩種通過 DataLoader 的調用方式更適合在訓練過程中使用:
# Wrap the dataset in a DataLoader: batches of one sample, shuffled, loaded by
# one worker process into pinned memory, with an incomplete last batch dropped.
test_data_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)
# Option 1: enumerate the loader and index into each batch.
# for i, data in enumerate(test_data_loader, 0):
#     print(data[0].shape, '..')
#     print(data[1].shape, '...')
# Option 2: unpack each batch into its image part and its mask part.
for batch in test_data_loader:
    data_batch, mask_batch = batch
    print(data_batch.size(), mask_batch.size())
還有一種通過from torchvision.datasets import ImageFolder來訪問文件數據
不過,我覺得這種更加適合在分類任務中應用
from torchvision.datasets import ImageFolder
#
# Build a dataset from a directory tree; no transform is applied on load.
dataset_data = ImageFolder('../../test_image/37_simple/0001/', transform=None)
# Mapping from class name to integer class index.
print(dataset_data.class_to_idx)
# The attribute holding the (image path, class index) pairs is `imgs`;
# `img` does not exist on ImageFolder and raised AttributeError.
print(dataset_data.imgs)
這種還沒有用過