PyTorch實現數據增強(kaggle環境)

一、數據增強方法:

1. 對圖片進行比例縮放

2. 對圖片進行隨機位置的截取

3. 對圖片進行隨機水平和豎直翻轉

4. 對圖片進行隨機角度的旋轉

5. 對圖片進行亮度、對比度和顏色隨機變化

二、Torch中已經把這些方法內置在了torchvision中,可以直接調用

from PIL import Image
from torchvision import transforms as tfs
im = Image.open('../input/cat.jpg')
im

在這裏插入圖片描述

1. 隨機比例縮放

使用:torchvision.transforms.Resize(),參數1表示縮放圖片大小,可以爲tuple,參數2表示縮放方法,默認爲雙線性插值

print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100,200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im
before scale, shape: (121, 121)
after scale, shape: (200, 100)

在這裏插入圖片描述

2.隨機位置截取

使用:

(1)torchvision.transforms.RandomCrop(),參數爲截取圖片的大小

(2)torchvision.transforms.CenterCrop(),參數爲截取圖片的大小,但以原始圖片的中心爲中心

# 隨機裁剪
random_im = tfs.RandomCrop((60, 60))(im)
random_im

在這裏插入圖片描述

# 中心裁剪
center_im = tfs.CenterCrop((60, 60))(im)
center_im

在這裏插入圖片描述

3.隨機水平和豎直翻轉

使用torchvision.transforms.RandomHorizontalFlip() 和torchvision.transforms.RandomVerticalFlip()

無參數

horizontal_im = tfs.RandomHorizontalFlip()(im)
horizontal_im

在這裏插入圖片描述

vertical_im = tfs.RandomVerticalFlip()(im)
vertical_im

在這裏插入圖片描述

4.隨機角度旋轉

使用:torchvision.transforms.RandomRotation()。參數爲旋轉的角度。比如20,則會隨機在-20~20之間進行旋轉

rotation_im = tfs.RandomRotation(45)(im)
rotation_im

在這裏插入圖片描述

5. 亮度、對比度和顏色

torchvision.transforms.ColorJitter()。參數1爲亮度,參數2爲對比度,參數3爲飽和度

bright_im = tfs.ColorJitter(brightness=1)(im)  #隨機在0~2之間變化,1
bright_im

在這裏插入圖片描述

contrast_im = tfs.ColorJitter(contrast=1)(im) #隨機在0~2之間變化,
contrast_im

在這裏插入圖片描述

color_im = tfs.ColorJitter(hue=0.5)(im) # 隨機從 -0.5 ~ 0.5 之間變化
color_im

在這裏插入圖片描述

三、聯合使用數據增強方法

使用 torchvision.transforms.Compose()。利用List進行組裝,然後傳遞給Compose

im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
fig, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
    for j in range(ncols):
        figs[i][j].imshow(im_aug(im))
        figs[i][j].axes.get_xaxis().set_visible(False)
        figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

在這裏插入圖片描述

四、使用數據增強

(1)訓練集採用數據增強

def train_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(120),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(96),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

(2)測試集不採用數據增強

def test_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
     ])
    x = im_aug(x)
    return x

訓練過程省略。。。。。。

通過數據增強,訓練集準確率會下降,因爲數據特徵變的多樣性,更加難訓練。

通過數據增強,測試集準確率會上升,所以模型的泛化能力提高了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章