Pytorch學習(三)--- 使用torchvision.transforms快速對圖像數據做數據增強

在深度學習任務中,通常讀入數據後,我們都需要對數據做transform操作,最後纔將transform後的數據送入模型進行訓練測試。
一個完整數據流pipeline可以定義爲如下:

讀取數據 -> transform -> 模型

本文學習pipeline中的transform部分(torchvision.transforms)。
torchvision.transforms是torchvision中的一個用於數據增強的包,包含了很多transform操作。

torchvision.transforms.Compose(transforms)

作用:將多個transform組合起來使用。
其源碼如下:

class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

可以看到主要的__call__方法就是對輸入圖像img循環所有的transform操作。

使用例子:


train_transform = transforms.Compose([
                            # transforms.RandomGrayscale(),
                            transforms.Resize((512, 512)),
                            transforms.RandomAffine(5),
                            # transforms.ColorJitter(hue=.05, saturation=.05),
                            # transforms.RandomCrop((88, 88)),
                            transforms.RandomHorizontalFlip(),
                            transforms.RandomVerticalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])




class QRDataset(Dataset):
    def __init__(self, img_df, transform=None):
        self.img_df = img_df
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        start_time = time.time()
        img = Image.open(self.img_df.iloc[index]['id']).convert('RGB')
        
        if self.transform is not None:
            img = self.transform(img)
        return img,torch.from_numpy(np.array(self.img_df.iloc[index]['label']))
    
    def __len__(self):
        return len(self.img_df)



train_loader = torch.utils.data.DataLoader(
        QRDataset(train_jpg.iloc[train_idx],
				train_transform,
        ), batch_size=10, shuffle=True, num_workers=20, pin_memory=True
    )
 ])

這裏定義了ResizeRandomAffineRandomHorizontalFlip等數據預處理操作,並最終作爲數據讀取類QRDataset的一個參數傳入,可以在內部方法__getitem__中實現數據增強操作。

torchvision.transforms.CenterCrop(size)

作用:將給定的PIL.Image進行中心切割,得到給定的sizesize可以是tuple(target_height, target_width)size也可以是一個Integer,在這種情況下,切出來的圖片的形狀是正方形。

torchvision.transforms.RandomCrop(size, padding=0)

作用:切割中心點的位置隨機選取。size可以是tuple也可以是Integer

torchvision.transforms.RandomHorizontalFlip

作用:隨機水平翻轉給定的PIL.Image,概率爲0.5。即:一半的概率翻轉,一半的概率不翻轉。

torchvision.transforms.RandomSizedCrop(size, interpolation=2)

作用:先將給定的PIL.Image隨機切,然後再resize成給定的size大小

torchvision.transforms.Pad(padding, fill=0)

作用:將給定的PIL.Image的所有邊用給定的pad value填充。 padding:要填充多少像素 fill:用什麼值
例子:

from torchvision import transforms
from PIL import Image
padding_img = transforms.Pad(padding=10, fill=0)
img = Image.open('test.jpg')

print(type(img))
print(img.size)

padded_img=padding(img)
print(type(padded_img))
print(padded_img.size)
<class 'PIL.PngImagePlugin.PngImageFile'>
(10, 10)
<class 'PIL.Image.Image'>
(30, 30) #由於上下左右都要填充10個像素,所以填充後的size是(30,30)
torchvision.transforms.ToTensor

作用:把一個取值範圍是[0,255]PIL.Image或者shape爲(H,W,C)numpy.ndarray,轉換成形狀爲[C,H,W],取值範圍是[0,1.0]torch.FloadTensor

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

在PyTorch中常用PIL庫來讀取圖像數據,因此這個方法相當於搭建了PIL.Image和Tensor的橋樑。另外要強調的是在做數據歸一化之前必須要把PIL.Image轉成Tensor

torchvision.transforms.Normalize(mean, std)

作用:歸一化操作。
給定均值:(R,G,B) 方差:(R,G,B),將會把Tensor正則化。

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

在深度學習分類檢測任務中,常用的是

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

這幾個值是根據imagnet數據集計算得到的均值、方差。

torchvision.transforms.ToPILImage

作用:將shape(C,H,W)Tensorshape(H,W,C)numpy.ndarray轉換成PIL.Image,值不變。

參考

https://www.jianshu.com/p/1ae863c1e66d
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-transform/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章