pytorch的語義分割------數據增廣

官方文檔:https://pytorch.org/docs/stable/torchvision/transforms.html?highlight=torchvision%20transforms%20functional#module-torchvision.transforms.functional

語義分割的label與分類問題不同,語義分割的label是一個mask,所以訓練圖像在做增廣的時候,mask也要做相應的變換。

如:原來的圖片和mask如下:

做旋轉後:

做翻轉後:

所以對於語義分割而言,其數據增廣是連同mask一起做的。

---------------------------------------------------------------------------------------------------

以下提供一份代碼,集成了 torchvision.transforms.functional 裏的比較常用的數據增廣方法(包括旋轉,上下翻轉,左右翻轉,裁剪,調整對比度,調整飽和度,調整亮度,中心裁剪等):

此代碼用於在訓練前對原來的數據做多種變換,從而生成更多的訓練數據。

此代碼完全可以直接調用,用一個.py文件存放即可。

代碼:

import random
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as tf

class Augmentation:
    """Paired image/mask augmentations for semantic segmentation.

    Every public method takes a PIL image and its PIL mask and returns
    ``(image_tensor, mask_tensor)`` as produced by ``tf.to_tensor``.
    Geometric transforms (rotate, flip, crops) are applied to both inputs
    with the same parameters; photometric transforms (contrast, brightness,
    saturation) modify the image only and leave the mask untouched.
    """

    def __init__(self):
        pass

    @staticmethod
    def _to_tensors(image, mask):
        """Convert the image/mask pair with ``tf.to_tensor``."""
        return tf.to_tensor(image), tf.to_tensor(mask)

    @staticmethod
    def _random_in_range(low, high):
        """Draw one random value from the range ``[low, high]``.

        ``RandomRotation.get_params`` is reused purely as a range sampler so
        that all random factors come from the same source as the rotation
        angle.
        """
        return transforms.RandomRotation.get_params([low, high])

    def rotate(self, image, mask, angle=None):
        """Rotate image and mask by the same angle.

        :param angle: ``None`` picks a random angle in -180..180; a list or
            tuple picks one of its entries at random; any other value is
            used directly as the angle in degrees.
        """
        if angle is None:
            angle = self._random_in_range(-180, 180)
        if isinstance(angle, (list, tuple)):
            angle = random.choice(angle)
        image = image.rotate(angle)
        mask = mask.rotate(angle)
        return self._to_tensors(image, mask)

    def flip(self, image, mask):
        """Randomly flip image and mask horizontally and/or vertically.

        The two flips are decided by two independent random draws, so the
        result may be unflipped, flipped on one axis, or flipped on both.
        """
        if random.random() > 0.5:
            image = tf.hflip(image)
            mask = tf.hflip(mask)
        if random.random() < 0.5:
            image = tf.vflip(image)
            mask = tf.vflip(mask)
        return self._to_tensors(image, mask)

    def randomResizeCrop(self, image, mask, scale=(0.3, 1.0), ratio=(1, 1)):
        """Crop the same random region from image and mask, then resize it.

        :param scale: range of the crop area relative to the input
            (0.3x to 1x by default).
        :param ratio: range of the crop's aspect ratio.

        The resize target is the input image height, passed as a single
        integer exactly as before.
        """
        # PIL reports size as (width, height). Reading the height from it
        # (instead of unpacking a numpy shape into two names) also works for
        # multi-channel images such as RGB, whose array shape has three
        # entries and previously raised ValueError here.
        resize_size = image.size[1]
        i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)
        image = tf.resized_crop(image, i, j, h, w, resize_size)
        # NOTE(review): the mask goes through resized_crop's default
        # interpolation; if that is not nearest-neighbour, label values can
        # be blended along class boundaries - confirm and pass an explicit
        # nearest-neighbour mode if the masks hold class indices.
        mask = tf.resized_crop(mask, i, j, h, w, resize_size)
        return self._to_tensors(image, mask)

    def adjustContrast(self, image, mask):
        """Change the image contrast by a random factor from the range 0..10."""
        factor = self._random_in_range(0, 10)
        image = tf.adjust_contrast(image, factor)
        return self._to_tensors(image, mask)

    def adjustBrightness(self, image, mask):
        """Change the image brightness by a random factor from the range 1..2."""
        factor = self._random_in_range(1, 2)
        image = tf.adjust_brightness(image, factor)
        return self._to_tensors(image, mask)

    def centerCrop(self, image, mask, size=None):
        """Center-crop image and mask to ``size``.

        :param size: crop size forwarded to ``tf.center_crop``; when omitted
            the original image size is used.
        """
        if size is None:
            # PIL's size is (width, height) while center_crop expects
            # (height, width); reverse it so non-square images keep their
            # shape instead of being cropped/padded to the transposed one.
            size = image.size[::-1]
        image = tf.center_crop(image, size)
        mask = tf.center_crop(mask, size)
        return self._to_tensors(image, mask)

    def adjustSaturation(self, image, mask):
        """Change the image saturation by a random factor from the range 1..2."""
        factor = self._random_in_range(1, 2)
        image = tf.adjust_saturation(image, factor)
        return self._to_tensors(image, mask)


# (option id, Augmentation method name, image file suffix, mask file suffix).
# The suffixes are reproduced exactly as the script has always written them,
# including the irregular ones, so existing output names do not change.
_AUGMENTATION_STEPS = (
    (1, 'rotate', '__rotate.jpg', '_rotate_mask.jpg'),
    (2, 'flip', '_filp.jpg', '_filp_mask.jpg'),
    (3, 'randomResizeCrop', '_ResizeCrop.jpg', '_ResizeCrop_mask.jpg'),
    (4, 'adjustContrast', '_Contrast.jpg', '_Contrast_mask.jpg'),
    (5, 'centerCrop', '_centerCrop.jpg', '_centerCrop_mask.jpg'),
    (6, 'adjustBrightness', '_Brightness.jpg', '_Brightness.jpg'),
    (7, 'adjustSaturation', '_Saturation.jpg', '_Saturation.jpg'),
)


def _list_files_sorted(top):
    """Return the paths of all files under ``top`` (recursively), sorted."""
    found = []
    for root, _dirs, files in os.walk(top):
        for f in files:
            found.append(os.path.join(root, f))
    return sorted(found)


def augmentationData(image_path, mask_path, option=(1, 2, 3, 4, 5, 6, 7), save_dir=None):
    '''
    Generate augmented image/mask pairs on disk before training.

    :param image_path: directory containing the images (searched recursively)
    :param mask_path: directory containing the masks (searched recursively)
    :param option: which augmentations to apply: 1 rotate, 2 flip,
        3 random crop resized back, 4 contrast, 5 center crop (not resized
        back), 6 brightness, 7 saturation. Unknown values are ignored.
    :param save_dir: output directory; results are written to its ``img``
        and ``mask`` sub-directories, which are created if missing.
    :raises TypeError: if ``save_dir`` is not given.

    Images and masks are paired by position after sorting both file lists,
    so corresponding files must sort into the same order.
    '''
    if save_dir is None:
        # Fail with a clear message instead of an opaque error from
        # os.path.join(None, ...); the exception type is the same as before.
        raise TypeError('save_dir must be given: it is the directory the augmented data is written to')
    aug_image_savedDir = os.path.join(save_dir, 'img')
    aug_mask_savedDir = os.path.join(save_dir, 'mask')
    if not os.path.exists(aug_image_savedDir):
        os.makedirs(aug_image_savedDir)
        print('create aug image dir.....')
    if not os.path.exists(aug_mask_savedDir):
        os.makedirs(aug_mask_savedDir)
        print('create aug mask dir.....')

    # os.walk yields entries in arbitrary order; sort both lists so that the
    # positional pairing below does not depend on the file system.
    images = _list_files_sorted(image_path)
    masks = _list_files_sorted(mask_path)
    if len(images) != len(masks):
        print('warning: %d images but %d masks; unpaired files are skipped'
              % (len(images), len(masks)))
    datas = list(zip(images, masks))

    aug = Augmentation()
    to_pil = transforms.ToPILImage()
    # Output files are numbered starting after the number of input pairs.
    num = len(datas)

    # NOTE(review): masks are written as JPEG, which is lossy and can alter
    # label values - confirm this is acceptable for the dataset in use.
    for image_file, mask_file in datas:
        with Image.open(image_file) as image, Image.open(mask_file) as mask:
            for option_id, method_name, image_suffix, mask_suffix in _AUGMENTATION_STEPS:
                if option_id not in option:
                    continue
                num += 1
                image_tensor, mask_tensor = getattr(aug, method_name)(image, mask)
                to_pil(image_tensor).save(os.path.join(save_dir, 'img', str(num) + image_suffix))
                to_pil(mask_tensor).save(os.path.join(save_dir, 'mask', str(num) + mask_suffix))


# Example run: augment the ISBI training set with every available option.
augmentationData(
    r'E:\datasets\isbi\train\images',
    r'E:\datasets\isbi\train\label',
    save_dir=r'E:\代碼\mytext\suanfa\aug',
)

運行效果:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章