PyTorch框架中有一個很常用的包:torchvision
torchvision主要由3個子包構成:torchvision.datasets
、torchvision.models
、torchvision.transforms
詳細內容可參考:http://pytorch.org/docs/master/torchvision/index.html
GitHub:https://github.com/pytorch/vision/tree/master/torchvision。
這篇主要介紹torchvision.transforms,基本上PyTorch中的resize、crop、normalize等常見的數據預處理及數據增強(data augmentation)操作都可以通過該接口實現。
torchvision.transforms主要涉及兩個文件:transforms.py
和functional.py
,在transforms.py
中定義了各種data augmentation的類,在每個類中通過調用functional.py中對應的函數完成data augmentation操作。
$ vim /home/lwp/.local/lib/python2.7/site-packages/torchvision/transforms/transforms.py
使用示例,這是Re-ID MGN模型實現代碼中的一段,https://github.com/lwplw/re-id_mgn/blob/master/pytorch_MGN/data/__init__.py,用到了Resize
、RandomHorizontalFlip
、ToTensor
、Normalize
:
from importlib import import_module
from torchvision import transforms
from utils.random_erasing import RandomErasing
from data.sampler import RandomSampler
from torch.utils.data import dataloader
class Data:
    """Build the train, test and query data loaders from parsed options.

    Args:
        args: Options object. The attributes read here are ``height``,
            ``width``, ``random_erasing``, ``probability``, ``test_only``,
            ``data_train``, ``data_test``, ``batchid``, ``batchimage``,
            ``batchtest`` and ``nThread``.

    Attributes set on the instance:
        trainset: Training dataset (only set when ``args.test_only`` is false).
        train_loader: Training ``DataLoader``, or ``None`` when
            ``args.test_only`` is true.
        testset, queryset: Datasets built with the ``'test'`` and ``'query'``
            split names.
        test_loader, query_loader: ``DataLoader`` objects over the two
            datasets above.

    Raises:
        ValueError: If ``args.data_test`` is not one of the supported
            dataset names.
    """

    # Dataset names accepted for ``args.data_test``.
    _SUPPORTED_TEST_SETS = ('Market1501',)

    def __init__(self, args):
        # interpolation=3 is passed straight through to ``transforms.Resize``.
        # NOTE(review): presumably PIL's bicubic filter code -- confirm.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        if args.random_erasing:
            # Random erasing is appended last, i.e. after normalisation.
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))
        train_transform = transforms.Compose(train_list)

        # Evaluation pipeline: same resize/normalise steps, no random steps.
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        if not args.test_only:
            # The dataset class lives in ``data.<lower-cased dataset name>``.
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
            # Sample selection is delegated to RandomSampler; one batch holds
            # batchid * batchimage samples.
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread,
            )
        else:
            self.train_loader = None

        if args.data_test not in self._SUPPORTED_TEST_SETS:
            raise ValueError('Unsupported test dataset: {0!r}'.format(args.data_test))

        # Import the module named after the *test* dataset: the class fetched
        # from it below is ``args.data_test``, so deriving the module name from
        # ``args.data_train`` breaks whenever the two options differ.
        module_test = import_module('data.' + args.data_test.lower())
        dataset_cls = getattr(module_test, args.data_test)
        self.testset = dataset_cls(args, test_transform, 'test')
        self.queryset = dataset_cls(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(
            self.testset, batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(
            self.queryset, batch_size=args.batchtest, num_workers=args.nThread)
各種操作的類定義在transforms.py
文件中:
from . import functional as F
,導入了functional.py
中具體的data augmentation函數;__all__
列表定義了可以從外部import的函數名或類名。
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
import collections
import warnings
from . import functional as F
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
}
Compose()
用來管理各個transform,其中__call__
方法就是對輸入img遍歷所有的transform操作。
class Compose(object):
    """Chain several transforms into one callable.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in
            list order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Feed ``img`` through every transform and return the final result."""
        result = img
        for step in self.transforms:
            result = step(result)
        return result

    def __repr__(self):
        # One line per transform, wrapped in ``ClassName(`` ... ``)``.
        body = ''.join('\n {0}'.format(step) for step in self.transforms)
        return '{0}({1}\n)'.format(self.__class__.__name__, body)
ToTensor()
Convert a PIL Image
or numpy.ndarray
to tensor.
在做數據歸一化之前必須要把PIL Image
轉成Tensor
,其它resize或crop操作不需要。
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0]. The conversion itself is delegated to ``F.to_tensor``.
    """

    def __call__(self, pic):
        """Convert ``pic`` to a tensor.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self):
        return '{0}()'.format(self.__class__.__name__)
ToPILImage()
Convert a tensor
or an ndarray
to PIL Image
.
ToTensor()
的反向操作。
Normalize()
數據歸一化處理,調用前數據需處理成Tensor
。
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean ``(M1,...,Mn)`` and std ``(S1,..,Sn)`` for ``n`` channels, each
    channel of the input ``torch.*Tensor`` is transformed as
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``.
    The arithmetic is delegated to ``F.normalize``.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Normalize ``tensor`` with the stored statistics.

        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        normalized = F.normalize(tensor, self.mean, self.std)
        return normalized

    def __repr__(self):
        details = '(mean={0}, std={1})'.format(self.mean, self.std)
        return self.__class__.__name__ + details
Resize()
對PIL Image
實現resize
操作。
- 如果輸入爲單個int值,則將輸入圖像的短邊resize到這個int數,長邊根據對應比例調整,圖像長寬比保持不變。
- 如果輸入爲(h,w),且h、w爲int,則直接將輸入圖像resize到(h,w)尺寸,圖像的長寬比可能會發生變化。
在__call__
方法中調用了functional.py
腳本中的resize函數來完成resize操作。因爲輸入是PIL Image
,所以resize函數基本是在調用Image的各種方法。
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Raises:
        AssertionError: If ``size`` is neither an int nor an iterable of
            length 2.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10, so look the ABC
        # up in ``collections.abc`` first and fall back for Python 2.
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2), \
            'size must be an int or an iterable of length 2'
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Resize ``img`` via ``F.resize``.

        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        # Map the interpolation code to its readable PIL constant name.
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
CenterCrop()
以輸入圖像img的中心作爲中心點進行指定size的crop操作,在數據增強中一般不會去使用該方法。因爲當size固定時,對於同一張img,N次CenterCrop的結果是一樣的。
size可以給單個int
值,也可以給(int(size), int(size))
class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. A single
            number yields a square crop ``(size, size)``; any other value
            (e.g. a sequence like (h, w)) is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img):
        """Crop ``img`` around its center via ``F.center_crop``.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        cropped = F.center_crop(img, self.size)
        return cropped

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
RandomCrop()
RandomCrop相比前面的CenterCrop要更加常用一些,兩者的區別在於RandomCrop在crop時的中心點座標是隨機的,不再是輸入圖像的中心座標,因此基本上每次crop生成的圖像都是不同的。
class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: Raised by ``random.randint`` when the image is smaller
                than ``output_size`` along either axis.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            # Image already has the requested size: only one crop is possible.
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """Pad ``img`` as configured and crop it at a random location.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        padding = self.padding
        # A plain ``padding > 0`` raises TypeError on Python 3 when ``padding``
        # is the documented length-4 sequence, so only numbers are compared
        # against zero; any sequence is handed to ``F.pad`` as-is.
        if padding is not None and (not isinstance(padding, numbers.Number) or padding > 0):
            img = F.pad(img, padding)

        # pad the width if needed (img.size is (width, height), self.size is (h, w))
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
RandomHorizontalFlip()
圖像隨機水平翻轉,翻轉概率爲0.5
。
class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Flip ``img`` via ``F.hflip`` when a uniform draw falls below ``p``.

        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        should_flip = random.random() < self.p
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        suffix = '(p={})'.format(self.p)
        return self.__class__.__name__ + suffix
RandomVerticalFlip()
圖像隨機垂直翻轉
class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Flip ``img`` via ``F.vflip`` when a uniform draw falls below ``p``.

        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        should_flip = random.random() < self.p
        return F.vflip(img) if should_flip else img

    def __repr__(self):
        suffix = '(p={})'.format(self.p)
        return self.__class__.__name__ + suffix
RandomResizedCrop()
CenterCrop
和RandomCrop
在crop時是固定size
,RandomResizedCrop
則是random size
的crop。
該類源碼需要3個參數:size
、scale
和ratio
,這裏我在使用中將接口中size
修改成了size_h, size_w
。方法爲先crop,再resize到指定尺寸。
crop時,其中心點座標和寬高是由get_params
方法得到的,首先在scale
限定的數值範圍內隨機生成一個數,用這個數乘以輸入圖像的面積作爲crop後圖像的面積,然後在ratio
限定的數值範圍內隨機生成一個數,表示寬高比,根據這兩個值就可以得到crop圖像的寬高。crop圖像的中心點座標,和RandomCrop類一樣是隨機生成的。
class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a
    random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
    is made. This crop is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size_h: first element of the output size
        size_w: second element of the output size; the pair is stored as
            ``(size_h, size_w)`` and passed to ``F.resized_crop``
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size_h, size_w, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        self.size = (size_h, size_w)
        self.scale = scale
        self.ratio = ratio
        self.interpolation = interpolation

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        img_w, img_h = img.size[0], img.size[1]
        full_area = img_w * img_h

        # Up to ten attempts at drawing a crop that fits inside the image.
        for _ in range(10):
            target_area = random.uniform(*scale) * full_area
            aspect_ratio = random.uniform(*ratio)

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Swap width and height half of the time.
            if random.random() < 0.5:
                crop_w, crop_h = crop_h, crop_w

            if crop_w <= img_w and crop_h <= img_h:
                top = random.randint(0, img_h - crop_h)
                left = random.randint(0, img_w - crop_w)
                return top, left, crop_h, crop_w

        # Fallback: centered square crop whose side is the shorter image edge.
        side = min(img_w, img_h)
        top = (img_h - side) // 2
        left = (img_w - side) // 2
        return top, left, side, side

    def __call__(self, img):
        """Crop a random region of ``img`` and resize it to ``self.size``.

        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, top, left, crop_h, crop_w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        fields = [
            'size={0}'.format(self.size),
            'scale={0}'.format(tuple(round(s, 4) for s in self.scale)),
            'ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)),
            'interpolation={0}'.format(interpolate_str),
        ]
        return self.__class__.__name__ + '(' + ', '.join(fields) + ')'