深度學習之圖像的數據增強

在圖像的深度學習中,爲了豐富圖像訓練集,更好的提取圖像特徵,泛化模型(防止模型過擬合),一般都會對數據圖像進行數據增強。數據增強,常用的方式,就是旋轉圖像,剪切圖像,改變圖像色差,扭曲圖像特徵,改變圖像尺寸大小,增強圖像噪音(一般使用高斯噪音,鹽椒噪音)等。但是需要注意,不要加入其他圖像輪廓的噪音.

  對於常用的圖像的數據增強的實現,如下:

  1 # -*- coding:utf-8 -*-
  2 """數據增強
  3    1. 翻轉變換 flip
  4    2. 隨機修剪 random crop
  5    3. 色彩抖動 color jittering
  6    4. 平移變換 shift
  7    5. 尺度變換 scale
  8    6. 對比度變換 contrast
  9    7. 噪聲擾動 noise
 10    8. 旋轉變換/反射變換 Rotation/reflection
 11    author: XiJun.Gong
 12    date:2016-11-29
 13 """
 14 
 15 from PIL import Image, ImageEnhance, ImageOps, ImageFile
 16 import numpy as np
 17 import random
 18 import threading, os, time
 19 import logging
 20 
 21 logger = logging.getLogger(__name__)
 22 ImageFile.LOAD_TRUNCATED_IMAGES = True
 23 
 24 
 25 class DataAugmentation:
 26     """
 27     包含數據增強的八種方式
 28     """
 29 
 30 
 31     def __init__(self):
 32         pass
 33 
 34     @staticmethod
 35     def openImage(image):
 36         return Image.open(image, mode="r")
 37 
 38     @staticmethod
 39     def randomRotation(image, mode=Image.BICUBIC):
 40         """
 41          對圖像進行隨機任意角度(0~360度)旋轉
 42         :param mode 鄰近插值,雙線性插值,雙三次B樣條插值(default)
 43         :param image PIL的圖像image
 44         :return: 旋轉轉之後的圖像
 45         """
 46         random_angle = np.random.randint(1, 360)
 47         return image.rotate(random_angle, mode)
 48 
 49     @staticmethod
 50     def randomCrop(image):
 51         """
 52         對圖像隨意剪切,考慮到圖像大小範圍(68,68),使用一個一個大於(36*36)的窗口進行截圖
 53         :param image: PIL的圖像image
 54         :return: 剪切之後的圖像
 55 
 56         """
 57         image_width = image.size[0]
 58         image_height = image.size[1]
 59         crop_win_size = np.random.randint(40, 68)
 60         random_region = (
 61             (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1,
 62             (image_height + crop_win_size) >> 1)
 63         return image.crop(random_region)
 64 
 65     @staticmethod
 66     def randomColor(image):
 67         """
 68         對圖像進行顏色抖動
 69         :param image: PIL的圖像image
 70         :return: 有顏色色差的圖像image
 71         """
 72         random_factor = np.random.randint(0, 31) / 10.  # 隨機因子
 73         color_image = ImageEnhance.Color(image).enhance(random_factor)  # 調整圖像的飽和度
 74         random_factor = np.random.randint(10, 21) / 10.  # 隨機因子
 75         brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # 調整圖像的亮度
 76         random_factor = np.random.randint(10, 21) / 10.  # 隨機因1子
 77         contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # 調整圖像對比度
 78         random_factor = np.random.randint(0, 31) / 10.  # 隨機因子
 79         return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # 調整圖像銳度
 80 
 81     @staticmethod
 82     def randomGaussian(image, mean=0.2, sigma=0.3):
 83         """
 84          對圖像進行高斯噪聲處理
 85         :param image:
 86         :return:
 87         """
 88 
 89         def gaussianNoisy(im, mean=0.2, sigma=0.3):
 90             """
 91             對圖像做高斯噪音處理
 92             :param im: 單通道圖像
 93             :param mean: 偏移量
 94             :param sigma: 標準差
 95             :return:
 96             """
 97             for _i in range(len(im)):
 98                 im[_i] += random.gauss(mean, sigma)
 99             return im
100 
101         # 將圖像轉化成數組
102         img = np.asarray(image)
103         img.flags.writeable = True  # 將數組改爲讀寫模式
104         width, height = img.shape[:2]
105         img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
106         img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
107         img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
108         img[:, :, 0] = img_r.reshape([width, height])
109         img[:, :, 1] = img_g.reshape([width, height])
110         img[:, :, 2] = img_b.reshape([width, height])
111         return Image.fromarray(np.uint8(img))
112 
113     @staticmethod
114     def saveImage(image, path):
115         image.save(path)
116 
117 
118 def makeDir(path):
119     try:
120         if not os.path.exists(path):
121             if not os.path.isfile(path):
122                 # os.mkdir(path)
123                 os.makedirs(path)
124             return 0
125         else:
126             return 1
127     except Exception, e:
128         print str(e)
129         return -2
130 
131 
132 def imageOps(func_name, image, des_path, file_name, times=5):
133     funcMap = {"randomRotation": DataAugmentation.randomRotation,
134                "randomCrop": DataAugmentation.randomCrop,
135                "randomColor": DataAugmentation.randomColor,
136                "randomGaussian": DataAugmentation.randomGaussian
137                }
138     if funcMap.get(func_name) is None:
139         logger.error("%s is not exist", func_name)
140         return -1
141 
142     for _i in range(0, times, 1):
143         new_image = funcMap[func_name](image)
144         DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name))
145 
146 
147 opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}
148 
149 
150 def threadOPS(path, new_path):
151     """
152     多線程處理事務
153     :param src_path: 資源文件
154     :param des_path: 目的地文件
155     :return:
156     """
157     if os.path.isdir(path):
158         img_names = os.listdir(path)
159     else:
160         img_names = [path]
161     for img_name in img_names:
162         print img_name
163         tmp_img_name = os.path.join(path, img_name)
164         if os.path.isdir(tmp_img_name):
165             if makeDir(os.path.join(new_path, img_name)) != -1:
166                 threadOPS(tmp_img_name, os.path.join(new_path, img_name))
167             else:
168                 print 'create new dir failure'
169                 return -1
170                 # os.removedirs(tmp_img_name)
171         elif tmp_img_name.split('.')[1] != "DS_Store":
172             # 讀取文件並進行操作
173             image = DataAugmentation.openImage(tmp_img_name)
174             threadImage = [0] * 5
175             _index = 0
176             for ops_name in opsList:
177                 threadImage[_index] = threading.Thread(target=imageOps,
178                                                        args=(ops_name, image, new_path, img_name,))
179                 threadImage[_index].start()
180                 _index += 1
181                 time.sleep(0.2)
182 
183 
184 if __name__ == '__main__':
185     threadOPS("/home/pic-image/train/12306train",
186               "/home/pic-image/train/12306train3")
鏈接: http://www.cnblogs.com/gongxijun/p/6117588.html
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章