Custom Transform Modules in PyTorch (torchvision) - Deep Learning

1 Problem

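When implementing image-to-image models in PyTorch, the same transform parameters must be applied to both images of a pair (i.e. the transforms must be synchronized). torchvision already provides many transforms, but they operate on a single image: applied separately to the input image and the target image, each call draws its own random parameters, so the two images end up transformed inconsistently. I could not find a good solution online, so I implemented one myself. After reading the transform code in torchvision, it turned out to be straightforward to write list-based versions modeled on it.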

2 How Random Scale and Crop Works

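Scale the original image by a factor of scale=r, then crop it back to the original size at a random position. For the blue region in the center of the image (in the original figure) not to be cut off by the crop, the crop size must be at least the minimum crop size that still covers that region. Reading n as the side length of the image (and therefore of the crop) and m as the side length of the central region, the scaled region reaches r(m+n)/2 from the image edge in the worst case, so n >= r(m+n)/2, which gives r <= 2n/(m+n).
Substituting n=256 and m=172 gives r of roughly 1.2. Conversely, if r is drawn from [1.0, 1.2], then m is at most 2*256/1.2 - 256, about 170.67; in practice take m=170 or 172.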
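
The arithmetic can be checked with a short script. This is a sketch that assumes n is the side length of the image and of the crop, and m is the side length of the central region that must survive; under that reading the largest safe scale is 2n/(m+n). The function names are made up for the example.

```python
def max_safe_scale(n, m):
    """Largest scale r for which an n x n crop always contains the central m x m region."""
    return 2.0 * n / (m + n)


def max_region_size(n, r_max):
    """Largest central region side m that survives every crop when r <= r_max."""
    return 2.0 * n / r_max - n


def region_survives(n, m, r):
    """Check the two extreme crop offsets, 0 and r*n - n, along one axis."""
    lo = r * (n - m) / 2.0   # near edge of the scaled central region
    hi = r * (n + m) / 2.0   # far edge of the scaled central region
    max_offset = r * n - n   # largest possible crop offset
    # Offset 0 needs hi <= n; the largest offset needs lo >= max_offset.
    return hi <= n and lo >= max_offset


print(round(max_safe_scale(256, 172), 3))    # 1.196
print(round(max_region_size(256, 1.2), 2))   # 170.67
print(region_survives(256, 170, 1.2))        # True
print(region_survives(256, 172, 1.2))        # False
```

With m=172 the bound is missed by a fraction of a pixel at r=1.2, which is why both 170 and 172 are usable in practice.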

3 Implementation

Talk is cheap, show me the code.

#coding=utf-8
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import random
import numpy as np
from PIL import Image
import numbers
import collections
import sys

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    # collections.abc is a submodule and must be imported explicitly.
    import collections.abc
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable



def get_sync_transform(opt):
    transform_list = []

    osize = [opt.loadSize, opt.loadSize]
    transform_list.append(Resize(osize, Image.BILINEAR))
    transform_list.append(RandomScaleCrop())

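    if opt.isTrain and not opt.no_flip:
        # torchvision's RandomHorizontalFlip expects a single image, so use the
        # list-based version defined below to flip the whole pair together.
        transform_list.append(RandomHorizontalFlip())

    transform_list += [ToTensor(),
                       Normalize((0.5, 0.5, 0.5),
                                 (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


class RandomHorizontalFlip(object):
    """Horizontally flip all images in the list together with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, images):
        # Draw once so that every image in the pair gets the same decision.
        if random.random() < self.p:
            return [F.hflip(img) for img in images]
        return images

    def __repr__(self):
        return self.__class__.__name__ + '(p={0})'.format(self.p)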


class RandomScaleCrop(object):
    """Randomly upscale a list of PIL Images and crop them back to the original size.

    Every image in the list is resized with the same random scale factor and
    cropped at the same random offset, so an input/target pair stays aligned.

    Args:
        min_scale (float): Lower bound of the random scale factor. Default is 1.0.
        max_scale (float): Upper bound of the random scale factor. Default is 1.2.
    """

    def __init__(self, min_scale=1.0, max_scale=1.2):
        self.min_scale = min_scale
        self.max_scale = max_scale

    @staticmethod
    def get_params(img, min_scale, max_scale):
        """Get parameters for a random scale followed by a crop.

        Args:
            img (PIL Image): Reference image; its size is the crop size.
            min_scale (float): Lower bound of the random scale factor.
            max_scale (float): Upper bound of the random scale factor.

        Returns:
            tuple: params (shift_y, shift_x, nh, nw), i.e. the top and left
            crop offsets and the height and width of the scaled image.
        """
        w, h = img.size  # PIL reports size as (width, height)

        scale = random.uniform(min_scale, max_scale)

        nw = int(scale * w) + 1  # new width
        nh = int(scale * h) + 1  # new height
        # Offsets fall in [1, nw - w] and [1, nh - h], so the crop stays inside the scaled image.
        shift_x = int(np.ceil(np.random.uniform(0.01, nw - w)))
        shift_y = int(np.ceil(np.random.uniform(0.01, nh - h)))

        return shift_y, shift_x, nh, nw

    def __call__(self, images):
        """
        Args:
            images (list of PIL Image): Images to be scaled and cropped together.

        Returns:
            list of PIL Image: Scaled and cropped images, same size as the input.
        """
        tw, th       = images[0].size # target width, height (PIL size is (w, h))
        i, j, nh, nw = self.get_params(images[0], self.min_scale, self.max_scale)
        # F.resize expects the size as (h, w)
        imgs         = [F.resize(img, (nh, nw), Image.BILINEAR) for img in images]

        return [F.crop(img, i, j, th, tw) for img in imgs]

    def __repr__(self):
        return self.__class__.__name__ + '(min_scale={0}, max_scale={1})'.format(self.min_scale, self.max_scale)


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8
    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pics):
        """
        Args:
            pics (list of PIL Image or numpy.ndarray): Images to be converted to tensors.
        Returns:
            list of Tensor: Converted images.
        """
        return [F.to_tensor(pic) for pic in pics]

    def __repr__(self):
        return self.__class__.__name__ + '()'


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensors):
        """
        Args:
            tensors (list of Tensor): Tensor images of size (C, H, W) to be normalized.

        Returns:
            list of Tensor: Normalized Tensor images.
        """
        return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]


    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, images):
        """
        Args:
            images (list of PIL Image): Images to be scaled.

        Returns:
            list of PIL Image: Rescaled images.
        """
        return [F.resize(img, self.size, self.interpolation) for img in images]

    def __repr__(self):
        # _pil_interpolation_to_str is private to torchvision and is not imported here,
        # so print the interpolation constant directly.
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, self.interpolation)
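
A quick way to check that a list-based transform really keeps a pair aligned is to feed it two copies of the same image and compare the outputs. The sketch below is an illustration only and depends on Pillow alone; `PairedFlip`, `IndependentFlip` and `is_synchronized` are hypothetical names introduced for the example, and the check assumes the transform returns PIL Images (so it should run before any `ToTensor` step).

```python
import random
from PIL import Image, ImageChops, ImageOps


class PairedFlip(object):
    """Hypothetical list-in/list-out transform that shares one random decision."""

    def __call__(self, images):
        if random.random() < 0.5:
            return [ImageOps.mirror(img) for img in images]
        return list(images)


class IndependentFlip(object):
    """Counter-example: every image gets its own random decision."""

    def __call__(self, images):
        return [ImageOps.mirror(img) if random.random() < 0.5 else img
                for img in images]


def is_synchronized(transform, size=(64, 48), trials=50):
    """Return True if two copies of one image come out identical in every trial."""
    width, height = size
    base = Image.new('L', size)
    # Asymmetric pattern so that flips, crops and shifts all change the pixels.
    base.putdata([(3 * x + 5 * y) % 256 for y in range(height) for x in range(width)])
    for _ in range(trials):
        out_a, out_b = transform([base, base.copy()])
        if out_a.size != out_b.size:
            return False
        if ImageChops.difference(out_a, out_b).getbbox() is not None:
            return False
    return True


random.seed(0)
print(is_synchronized(PairedFlip()))       # shared decision keeps the pair aligned
print(is_synchronized(IndependentFlip()))  # independent decisions almost surely do not
```

The same helper can be pointed at any transform that takes and returns a list of PIL Images, such as a `RandomScaleCrop` instance from the listing above.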