版权归属:
更多关注:
1 问题
用PyTorch实现image-to-image任务时,需要对image pair施加参数完全相同的transform(即同步,sync)。torchvision已经实现了多种针对单张图片(single image)的transform,但其中的随机transform每次调用都会独立采样参数,分别作用于input image与target image时,两者得到的变换并不一致。上网查阅没有发现好的方法,因此只能自己实现;查看TorchVision的Transform源码后发现,完全可以仿照它来实现。
2 Random Scale and Crop原理
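将原始图片(边长为n)放大r倍(scale=r)后,再随机crop回原始大小n。为了保证crop后的图片中间的蓝色区域(边长为m)不被截断遗漏,需要满足crop size > min crop size,即 $n > r(m+n)/2$,进而得到 $r < \frac{2n}{m+n}$。
如果将n=256、m=172代入,r的上限约为1.2($2 \times 256 / (256 + 172) \approx 1.196$)。反过来,如果令r在[1.0, 1.2]之间,那么 $m = 2n/r - n \approx 170.7$,实际中取m=170或172。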
3 实现代码
Talk is cheap, show me the code.
#coding=utf-8
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import random
import numpy as np
from PIL import Image
import numbers
import collections
import sys
# ``Sequence`` and ``Iterable`` live in ``collections.abc`` from Python 3.3 on.
# Import the submodule explicitly rather than relying on ``collections.abc``
# having been loaded implicitly by a bare ``import collections``; fall back to
# the legacy location on interpreters that predate the move.
try:
    from collections.abc import Iterable, Sequence
except ImportError:  # Python < 3.3
    Sequence = collections.Sequence
    Iterable = collections.Iterable
class _SyncRandomHorizontalFlip(object):
    """Horizontally flip every image in a list together, with probability ``p``.

    A single random draw decides the outcome for the whole list, so an
    input/target pair is either flipped as a pair or left untouched.

    Args:
        p (float): Probability of flipping. Default is 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, images):
        """
        Args:
            images (list of PIL Image): Images to flip in lockstep.
        Returns:
            list of PIL Image: The flipped images, or ``images`` unchanged.
        """
        if random.random() < self.p:
            return [F.hflip(img) for img in images]
        return images

    def __repr__(self):
        return self.__class__.__name__ + '(p={0})'.format(self.p)


def get_sync_transform(opt):
    """Build the synchronised transform pipeline for a list of paired images.

    Every stage receives and returns a list of images/tensors, so the same
    random parameters are applied to each element of the pair.

    Args:
        opt: Options object. Reads ``opt.loadSize`` (side length used for the
            initial square resize), ``opt.isTrain`` and ``opt.no_flip``.
    Returns:
        transforms.Compose: Resize -> RandomScaleCrop -> (optional horizontal
        flip) -> ToTensor -> Normalize with mean and std of 0.5 per channel.
    """
    transform_list = [
        Resize([opt.loadSize, opt.loadSize], Image.BILINEAR),
        RandomScaleCrop(),
    ]
    if opt.isTrain and not opt.no_flip:
        # The stock transforms.RandomHorizontalFlip expects a single image;
        # the pipeline passes a list, so use the list-aware flip that makes
        # one decision for the whole pair.
        transform_list.append(_SyncRandomHorizontalFlip())
    transform_list += [
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(transform_list)
class RandomScaleCrop(object):
    """Randomly upscale a list of PIL images and crop them back to size.

    One scale factor and one crop offset are drawn per call and applied to
    every image in the list, so an input/target pair stays aligned. The
    output images have the same size as the first image in the list.

    Args:
        min_scale (float): Lower bound of the random scale factor.
            Default is 1.0.
        max_scale (float): Upper bound of the random scale factor.
            Default is 1.2.
    """

    def __init__(self, min_scale=1.0, max_scale=1.2):
        self.min_scale = min_scale
        self.max_scale = max_scale

    @staticmethod
    def get_params(img, min_scale, max_scale):
        """Draw the scaled size and crop offset for one random scale-crop.

        Args:
            img (PIL Image): Reference image; only ``img.size`` is read.
            min_scale (float): Lower bound of the scale factor.
            max_scale (float): Upper bound of the scale factor.
        Returns:
            tuple: ``(top, left, new_height, new_width)`` where ``top``/``left``
            is the crop origin inside the image scaled to
            ``new_height`` x ``new_width``.
        """
        w, h = img.size  # PIL reports size as (width, height)
        # NOTE: the scale comes from ``random`` while the offsets come from
        # ``np.random``; both generators must be seeded for reproducibility.
        scale = random.uniform(min_scale, max_scale)
        # The "+ 1" keeps the scaled image strictly larger than the original
        # whenever scale >= 1, so the offset range below is never empty.
        nw = int(scale * w) + 1  # new width
        nh = int(scale * h) + 1  # new height
        # ceil() of a value starting at 0.01 yields an integer offset >= 1.
        shift_x = int(np.ceil(np.random.uniform(0.01, nw - w)))
        shift_y = int(np.ceil(np.random.uniform(0.01, nh - h)))
        return shift_y, shift_x, nh, nw

    def __call__(self, images):
        """
        Args:
            images (list of PIL Image): Images to scale and crop with shared
                parameters; the first image defines the output size.
        Returns:
            list of PIL Image: Scaled-and-cropped images.
        """
        tw, th = images[0].size  # target width, height (PIL order is (w, h))
        i, j, nh, nw = self.get_params(images[0], self.min_scale, self.max_scale)
        # F.resize takes the size as (height, width).
        scaled = [F.resize(img, (nh, nw), Image.BILINEAR) for img in images]
        # F.crop takes (top, left, height, width).
        return [F.crop(img, i, j, th, tw) for img in scaled]

    def __repr__(self):
        return self.__class__.__name__ + '(min_scale={0}, max_scale={1})'.format(
            self.min_scale, self.max_scale)
class ToTensor(object):
    """Convert each ``PIL Image`` or ``numpy.ndarray`` in a list to a tensor.

    Every element is passed through ``F.to_tensor`` individually; the value
    scaling and layout rules are those of that torchvision function.
    """

    def __call__(self, pics):
        """
        Args:
            pics (list of PIL Image or numpy.ndarray): Images to be converted.
        Returns:
            list of Tensor: Converted images, in the same order as ``pics``.
        """
        tensors = []
        for pic in pics:
            tensors.append(F.to_tensor(pic))
        return tensors

    def __repr__(self):
        return '{0}()'.format(self.__class__.__name__)
class Normalize(object):
    """Normalize each tensor image in a list with a mean and standard deviation.

    Every tensor is handed to ``F.normalize`` together with the stored
    ``mean``, ``std`` and ``inplace`` values.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Forwarded unchanged to ``F.normalize``.
            Default is False.
    """

    def __init__(self, mean, std, inplace=False):
        self.inplace = inplace
        self.std = std
        self.mean = mean

    def __call__(self, tensors):
        """
        Args:
            tensors (list of Tensor): Tensor images to be normalized.
        Returns:
            list of Tensor: Normalized tensor images, in input order.
        """
        normalized = []
        for tensor in tensors:
            normalized.append(
                F.normalize(tensor, self.mean, self.std, self.inplace))
        return normalized

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(
            self.__class__.__name__, self.mean, self.std)
class Resize(object):
    """Resize every PIL Image in a list to the given size.

    Args:
        size (sequence or int): Desired output size, forwarded unchanged to
            ``F.resize``. Per torchvision's documentation, a sequence is
            read as (h, w), while an int matches the smaller edge of the
            image and keeps the aspect ratio.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, images):
        """
        Args:
            images (list of PIL Image): Images to be scaled.
        Returns:
            list of PIL Image: Rescaled images, in input order.
        """
        return [F.resize(img, self.size, self.interpolation) for img in images]

    def __repr__(self):
        # The name table is built locally: no ``_pil_interpolation_to_str``
        # is defined or imported in this module, so looking it up would
        # raise NameError. Unknown modes fall back to their plain repr.
        names = {
            Image.NEAREST: 'PIL.Image.NEAREST',
            Image.BILINEAR: 'PIL.Image.BILINEAR',
            Image.BICUBIC: 'PIL.Image.BICUBIC',
            Image.LANCZOS: 'PIL.Image.LANCZOS',
        }
        interpolate_str = names.get(self.interpolation, repr(self.interpolation))
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)