from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import pandas as pd
import cv2
import os
import glob
class Mydataset(Dataset):
    """Image-classification dataset laid out as one sub-directory per class.

    Files are discovered with ``glob`` under ``data_dir/<cls_name>/<suffix>``
    for every entry of ``cls_list``; the label of a file is the position of
    its class name in ``cls_list``.

    Args:
        data_dir: Root directory containing the class sub-directories.
        cls_list: Ordered class (sub-directory) names; the index of a name
            in this sequence is used as its integer label.
        transform: Optional callable applied to each loaded image.
        suffix: Glob pattern used to select files inside each class
            directory.

    Attributes:
        data_dir: The root directory passed in.
        df: 2-column object array of ``(file_path, label)`` rows.
        transform: The transform callable, or ``None``.
    """

    def __init__(self, data_dir, cls_list, transform=None, suffix="*.jpg"):
        super().__init__()
        self.data_dir = data_dir
        file_paths = []
        labels = []
        for index, cls_name in enumerate(cls_list):
            # glob returns files in filesystem-dependent order; sort so that
            # a given sample index always maps to the same file.
            file_list = sorted(
                glob.glob(os.path.join(data_dir, cls_name, suffix))
            )
            file_paths.extend(file_list)
            labels.extend([index] * len(file_list))
        # Stored as a plain object array so rows unpack as (path, label).
        self.df = pd.DataFrame(
            {
                "file_paths": file_paths,
                "labels": labels,
            },
            dtype="object",
        ).values
        self.transform = transform

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.df)

    def __getitem__(self, idex):
        """Load the image at position ``idex`` and return ``(image, label)``.

        The image is returned in RGB channel order (after ``transform`` if
        one was supplied).

        Raises:
            OSError: If the file cannot be read or decoded as an image.
        """
        img_name, label = self.df[idex]
        image = cv2.imread(img_name)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read image file: {img_name}")
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        else:
            # OpenCV decodes colour images as BGR; downstream consumers
            # (e.g. PIL-based transforms) expect RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# Training-time preprocessing / augmentation pipeline applied to each image.
transforms_train = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(256, 256)),
        # Random horizontal flip.
        transforms.RandomHorizontalFlip(),
        # Random rotation with degrees=10.
        transforms.RandomRotation(degrees=10),
        # Convert the image to a Tensor.
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
if __name__ == "__main__":
    # Smoke run: build the training dataset and iterate over it once,
    # printing the index of every batch the loader produces.
    train_path = '/Users/goby/data/tianshi_image/train_img'
    class_names = ["金", "木", "水", "火", "土"]
    train_data = Mydataset(
        train_path,
        class_names,
        transform=transforms_train,
    )

    BATCH_SIZE = 64
    dataloader = DataLoader(
        train_data,
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    # Only the batch index is reported; the batch contents are not used.
    for batch_index, _batch in enumerate(dataloader):
        print(batch_index)