使用 PyTorch 的 torch.utils.data.Dataset 類編寫自己的數據集類

原文鏈接:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

導入必要的庫

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Silence all warnings to keep the tutorial output clean.
# NOTE(review): "ignore" also hides Deprecation/FutureWarnings, e.g. for
# pandas APIs that were later removed.
import warnings
warnings.filterwarnings("ignore")

# Turn on matplotlib interactive mode so figures are drawn without blocking.
plt.ion()

用pandas讀入數據


# Read the landmark annotations and inspect a single row (row 65).
landmarks_frame = pd.read_csv(r'../data/faces/face_landmarks.csv')
n = 65
# Column 0 holds the image file name; the remaining columns hold the
# flattened landmark coordinates.
img_name = landmarks_frame.iloc[n, 0]
# Series.as_matrix() was removed in pandas 1.0; to_numpy() is its replacement.
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
# Regroup the flat coordinates into one (x, y) row per landmark.
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 landmarks: {}'.format(landmarks[:4]))

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

定義一個顯示圖片和landmarks的函數

# Helper: draw an image and overlay its landmark points.
def show_landmarks(image, landmarks):
    """Display ``image`` with its landmark points scattered on top in red.

    :param image: image array passed straight to ``plt.imshow``.
    :param landmarks: 2-D array whose column 0 holds x coordinates and
        column 1 holds y coordinates.
    """
    plt.imshow(image)
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # Short pause lets matplotlib process GUI events so the plot refreshes.
    plt.pause(0.001)
    
# Display the sample inspected above together with its landmarks.
plt.figure()
show_landmarks(
    image=io.imread(os.path.join('../data/faces', img_name)),
    landmarks=landmarks,
)
plt.show()

（圖：第 65 行樣本的人臉圖像及其 landmarks）

定義一個Dataset類, 繼承torch.utils.data.Dataset類

torch.utils.data.Dataset 是一個抽象類, 表示一個dataset.
自定義的dataset類需要繼承Dataset, 並且覆寫(override)以下方法:

  • __len__函數, len(dataset)返回數據集的長度
  • __getitem__函數, 支持dataset[i]尋址.
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.

    Each CSV row describes one sample: column 0 is an image file name
    relative to ``root_dir``, and the remaining columns are the landmark
    coordinates flattened as x0, y0, x1, y1, ...
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        :param csv_file: path to the CSV file with the annotations.
        :param root_dir: directory containing the images.
        :param transform: optional callable applied to every sample dict.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Build the sample for row ``idx`` and return it.

        The sample is ``{'image': <image array>, 'landmarks': <(N, 2) float
        array>}``, passed through ``self.transform`` if one was given.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Regroup the flattened coordinate columns into one (x, y) row per point.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy(dtype=float)
        landmarks = landmarks.reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        # Compare with None explicitly: the truthiness of an arbitrary
        # callable is not a reliable "was a transform supplied" test.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample
# Instantiate the dataset without any transform.
face_dataset = FaceLandmarksDataset(csv_file=r'../data/faces/face_landmarks.csv',
                                    root_dir='../data/faces/')
# Plot the first four samples side by side (fewer if the dataset is smaller).
fig = plt.figure()
for i in range(min(4, len(face_dataset))):
    sample = face_dataset[i]
    print(i, sample['image'].shape, sample['landmarks'].shape)
    ax = plt.subplot(1, 4, i+1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

# Lay out once all subplots exist, and show the figure after the loop so it
# is displayed even when the dataset has fewer than four samples.
plt.tight_layout()
plt.show()
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

（圖：前 4 個樣本及其 landmarks）

自定義圖像的變換

圖像縮放

#%%
# 自定義transforms
class Rescale(object):
    """Resize the image of a sample and scale its landmarks to match.

    :param output_size: int or tuple. A tuple is used directly as
        (height, width). An int becomes the length of the shorter image
        edge, and the longer edge is scaled to keep the aspect ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        landmarks = sample['landmarks']
        height, width = image.shape[:2]

        target = self.output_size
        if not isinstance(target, int):
            scaled_h, scaled_w = target
        elif height > width:
            # Width is the shorter edge: pin it to `target`.
            scaled_h, scaled_w = target * height / width, target
        else:
            # Height is the shorter (or equal) edge: pin it to `target`.
            scaled_h, scaled_w = target, target * width / height

        scaled_h = int(scaled_h)
        scaled_w = int(scaled_w)
        resized = transform.resize(image, (scaled_h, scaled_w))
        # Landmark column 0 is x (scales with width), column 1 is y (height).
        factors = [scaled_w / width, scaled_h / height]

        return {'image': resized, 'landmarks': landmarks * factors}

隨機裁剪

#%%
class RandomCrop(object):
    """Randomly crop the image of a sample and shift its landmarks to match.

    :param output_size: int or 2-tuple (height, width) of the crop; an int
        produces a square crop.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the cropped image and shifted landmarks.

        :raises ValueError: if the image is smaller than the crop size.
        """
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if h < new_h or w < new_w:
            raise ValueError(
                'Crop size {} is larger than image size {}'.format(
                    (new_h, new_w), (h, w)))
        # randint's upper bound is exclusive: +1 makes the last valid offset
        # reachable and allows a crop exactly as large as the image.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        image = image[top: top + new_h, left: left + new_w]
        # Move landmarks into crop coordinates; points outside the crop end
        # up negative or beyond the crop size.
        landmarks = landmarks - [left, top]
        return {'image': image, 'landmarks': landmarks}

將numpy的ndarrays轉換爲 Tensor

class ToTensor(object):
    """Convert the ndarrays of a sample to torch Tensors.

    The image goes from (height, width, channels) layout to
    (channels, height, width). A 2-D (grayscale) image first gets a single
    channel axis, so it becomes (1, height, width).
    """

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        if image.ndim == 2:
            # Grayscale image: add a trailing channel axis so transpose works.
            image = image[:, :, np.newaxis]
        # numpy images are H x W x C; torch images are C x H x W.
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
    

將transform組合

# Combine transforms and preview each one on the same sample.
scale = Rescale(256)
crop = RandomCrop(128)
# Rescale so the shorter side is 256, then take a random 224x224 crop.
composed = transforms.Compose([Rescale(256), 
                               RandomCrop(224)])
fig = plt.figure()
# face_dataset has no transform, so this is the raw sample dict.
sample = face_dataset[65]
print(type(sample))
# Apply each transform to that raw sample and plot the results side by side
# in a 1x3 grid, each titled with the transform's class name.
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)
    ax = plt.subplot(1, 3, i+1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

（圖：依次為 Rescale、RandomCrop 與 Compose 組合變換後的結果）

使用Dataloader遍歷自定義的dataset

# Build the dataset with the full preprocessing pipeline, then print the
# tensor sizes of (at most) the first four samples.
transfromed_dataset = FaceLandmarksDataset(
    csv_file='../data/faces/face_landmarks.csv',
    root_dir='../data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]),
)
for i in range(min(4, len(transfromed_dataset))):
    sample = transfromed_dataset[i]
    print(i, sample['image'].size(), sample['landmarks'].size())

# Batch and shuffle the samples. If worker processes raise BrokenPipeError,
# keep num_workers at 0 (load data in the main process).
dataloader = DataLoader(transfromed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)
# Helper: display one DataLoader batch with each image's landmarks overlaid.
def show_landmarks_batch(sample_batched, nrow=8, padding=2):
    """Show a batch of images as a grid with their landmarks overlaid.

    :param sample_batched: dict with 'image', a (B, C, H, W) tensor, and
        'landmarks', a (B, N, 2) tensor of (x, y) points.
    :param nrow: images per grid row; forwarded to ``utils.make_grid``.
    :param padding: pixel gap between grid cells; forwarded to
        ``utils.make_grid`` so the landmark offsets stay in sync with it.
    """
    images_batch = sample_batched['image']
    landmarks_batch = sample_batched['landmarks']
    batch_size = len(images_batch)
    # Use the real height and width separately; they may differ.
    im_h, im_w = images_batch.size(2), images_batch.size(3)

    grid = utils.make_grid(images_batch, nrow=nrow, padding=padding)
    # make_grid returns C x H x W; imshow expects H x W x C.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    # make_grid returns a lone image without any padding border.
    border = 0 if batch_size == 1 else padding
    for i in range(batch_size):
        # make_grid fills rows left-to-right, wrapping after `nrow` images.
        row, col = divmod(i, nrow)
        x_off = border + col * (im_w + border)
        y_off = border + row * (im_h + border)
        plt.scatter(landmarks_batch[i, :, 0].numpy() + x_off,
                    landmarks_batch[i, :, 1].numpy() + y_off,
                    s=10, marker='.', c='r')
    plt.title('Batch from dataloader')

# Iterate over the batches, printing each batch's tensor sizes; on the
# fourth batch (index 3) plot it and stop.
# NOTE(review): if the loader yields fewer than four batches, nothing is plotted.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        # Leave the interactive mode enabled by plt.ion() earlier,
        # presumably so this final plt.show() blocks until the window closes.
        plt.ioff()
        plt.show()
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

（圖：DataLoader 輸出的一個 batch 及其 landmarks）

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章