在圖像的深度學習中,爲了豐富圖像訓練集,更好地提取圖像特徵,提高模型的泛化能力(防止模型過擬合),一般都會對圖像數據進行數據增強。常用的數據增強方式包括:旋轉圖像,剪切圖像,改變圖像色差,扭曲圖像特徵,改變圖像尺寸大小,增加圖像噪音(一般使用高斯噪音,椒鹽噪音)等。但是需要注意,不要加入帶有其他圖像輪廓的噪音。
常用的圖像數據增強方式的實現如下:
1 # -*- coding:utf-8 -*-
2 """數據增強
3 1. 翻轉變換 flip
4 2. 隨機修剪 random crop
5 3. 色彩抖動 color jittering
6 4. 平移變換 shift
7 5. 尺度變換 scale
8 6. 對比度變換 contrast
9 7. 噪聲擾動 noise
10 8. 旋轉變換/反射變換 Rotation/reflection
11 author: XiJun.Gong
12 date:2016-11-29
13 """
14
15 from PIL import Image, ImageEnhance, ImageOps, ImageFile
16 import numpy as np
17 import random
18 import threading, os, time
19 import logging
20
21 logger = logging.getLogger(__name__)
22 ImageFile.LOAD_TRUNCATED_IMAGES = True
23
24
class DataAugmentation:
    """Stateless image-augmentation helpers built on PIL.

    Four operations are implemented: random rotation, centred crop of a
    random size, colour jitter and additive Gaussian noise.  Every operation
    is a ``staticmethod`` that takes a PIL image and returns a new PIL image.
    """

    def __init__(self):
        pass

    @staticmethod
    def openImage(image):
        """Open an image for reading.

        :param image: file name or file object accepted by ``Image.open``
        :return: the PIL image returned by ``Image.open``
        """
        return Image.open(image, mode="r")

    @staticmethod
    def randomRotation(image, mode=Image.BICUBIC):
        """Rotate the image by a random whole number of degrees in [1, 359].

        :param image: PIL image
        :param mode: resampling filter passed to ``image.rotate``
                     (bicubic by default)
        :return: the rotated image
        """
        random_angle = np.random.randint(1, 360)
        return image.rotate(random_angle, mode)

    @staticmethod
    def randomCrop(image, min_size=40, max_size=67):
        """Cut a centred square window whose side length is random.

        Only the size of the window is random; it is always centred on the
        image.  The defaults reproduce the original window range (40..67
        pixels), which was chosen for 68x68 inputs.  Both bounds are clamped
        to the shorter image side so the window never reaches outside the
        image.

        :param image: PIL image
        :param min_size: smallest window side length, inclusive
        :param max_size: largest window side length, inclusive
        :return: the cropped image
        """
        image_width, image_height = image.size
        shortest_side = min(image_width, image_height)
        lower = min(min_size, shortest_side)
        upper = max(lower, min(max_size, shortest_side))
        crop_win_size = np.random.randint(lower, upper + 1)
        left = (image_width - crop_win_size) // 2
        top = (image_height - crop_win_size) // 2
        right = (image_width + crop_win_size) // 2
        bottom = (image_height + crop_win_size) // 2
        return image.crop((left, top, right, bottom))

    @staticmethod
    def randomColor(image):
        """Jitter saturation, brightness, contrast and sharpness in turn.

        Each enhancement factor is drawn independently in steps of 0.1:
        saturation and sharpness from [0.0, 3.0], brightness and contrast
        from [1.0, 2.0].

        :param image: PIL image
        :return: the colour-jittered image
        """
        # (enhancer, low, high): factor = randint(low, high) / 10, high exclusive.
        stages = (
            (ImageEnhance.Color, 0, 31),
            (ImageEnhance.Brightness, 10, 21),
            (ImageEnhance.Contrast, 10, 21),
            (ImageEnhance.Sharpness, 0, 31),
        )
        for enhancer, low, high in stages:
            random_factor = np.random.randint(low, high) / 10.
            image = enhancer(image).enhance(random_factor)
        return image

    @staticmethod
    def randomGaussian(image, mean=0.2, sigma=0.3):
        """Add Gaussian noise to the pixel values of the image.

        The noise is drawn from NumPy's global random generator.  For arrays
        with a channel axis, at most the first three channels are perturbed,
        so the fourth channel of a four-channel image is left untouched.
        Two-dimensional (single-channel) images are perturbed as a whole.

        :param image: PIL image
        :param mean: mean of the noise, in pixel-value units
        :param sigma: standard deviation of the noise, in pixel-value units
        :return: a new 8-bit image with the noise applied
        """
        # np.array copies, so the buffer is writable without touching the
        # flags of a view onto PIL's memory; float64 keeps the arithmetic
        # from wrapping around in uint8.
        pixels = np.array(image, dtype=np.float64)
        if pixels.ndim == 3:
            channels = min(pixels.shape[2], 3)
            noise = np.random.normal(mean, sigma, pixels.shape[:2] + (channels,))
            pixels[:, :, :channels] += noise
        else:
            pixels += np.random.normal(mean, sigma, pixels.shape)
        # Round and clip into the valid 8-bit range before converting back.
        return Image.fromarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))

    @staticmethod
    def saveImage(image, path):
        """Write ``image`` to ``path`` using ``image.save``."""
        image.save(path)
116
117
def makeDir(path):
    """Create directory ``path`` (including parents) unless it already exists.

    The function never raises; failures are reported through the return code.

    :param path: directory to create
    :return: 0 when the directory was created, 1 when it already existed,
             -2 when it could not be created (the error text is printed),
             e.g. because ``path`` is an existing regular file
    """
    try:
        if os.path.isdir(path):
            return 1
        os.makedirs(path)
        return 0
    except Exception as e:
        # The directory may have been created concurrently between the check
        # and makedirs; that still counts as "already exists".
        if isinstance(e, OSError) and os.path.isdir(path):
            return 1
        print(str(e))
        return -2
130
131
def imageOps(func_name, image, des_path, file_name, times=5):
    """Apply one named augmentation to ``image`` several times and save each result.

    Output files are written into ``des_path`` and are named
    ``<func_name><round index><file_name>``.

    :param func_name: one of "randomRotation", "randomCrop", "randomColor",
                      "randomGaussian"
    :param image: PIL image handed to the augmentation
    :param des_path: directory that receives the output files
    :param file_name: original file name, used as the suffix of each output
    :param times: number of augmented copies to produce
    :return: -1 when ``func_name`` is unknown, otherwise None
    """
    dispatch = {
        "randomRotation": DataAugmentation.randomRotation,
        "randomCrop": DataAugmentation.randomCrop,
        "randomColor": DataAugmentation.randomColor,
        "randomGaussian": DataAugmentation.randomGaussian,
    }
    operation = dispatch.get(func_name)
    if operation is None:
        logger.error("%s is not exist", func_name)
        return -1

    for round_no in range(times):
        augmented = operation(image)
        target = os.path.join(des_path, func_name + str(round_no) + file_name)
        DataAugmentation.saveImage(augmented, target)
145
146
# Names of the augmentations applied to every image; each must be a key that
# imageOps understands.
opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}


def threadOPS(path, new_path):
    """Augment every image under ``path`` and write the results to ``new_path``.

    Directories are walked recursively and the sub-directory layout is
    mirrored below ``new_path``.  For each image, one thread per entry of
    ``opsList`` runs ``imageOps``; the threads are joined before the next
    image is opened.  Files that cannot be opened as images are logged and
    skipped.

    :param path: source image file, or a directory of images
    :param new_path: destination directory; created when it does not exist
    :return: -1 when a destination directory cannot be created,
             otherwise None
    """
    # makeDir signals failure with a negative code, so test the sign instead
    # of one specific value.
    if makeDir(new_path) < 0:
        print('create new dir failure')
        return -1

    if os.path.isdir(path):
        src_dir = path
        img_names = os.listdir(path)
    else:
        # A single file: split it so the join below rebuilds the same path
        # and the bare file name is used for the output files.
        src_dir, single_name = os.path.split(path)
        img_names = [single_name]

    for img_name in img_names:
        print(img_name)
        tmp_img_name = os.path.join(src_dir, img_name)
        if os.path.isdir(tmp_img_name):
            if threadOPS(tmp_img_name, os.path.join(new_path, img_name)) == -1:
                return -1
            continue
        if img_name == ".DS_Store":
            continue

        image = None
        try:
            image = DataAugmentation.openImage(tmp_img_name)
            # Decode the pixel data up front so the worker threads do not
            # each trigger the lazy load on the shared image at once.
            image.load()
        except (IOError, OSError) as err:
            if image is not None:
                image.close()
            logger.warning("skip unreadable image %s: %s", tmp_img_name, err)
            continue

        try:
            workers = [
                threading.Thread(target=imageOps,
                                 args=(ops_name, image, new_path, img_name))
                for ops_name in opsList
            ]
            for worker in workers:
                worker.start()
            # Wait for all operations so the image can be closed safely and
            # threads do not pile up across images.
            for worker in workers:
                worker.join()
        finally:
            image.close()
182
183
if __name__ == '__main__':
    import sys

    # Optional command-line overrides: argv[1] is the source path and
    # argv[2] the destination directory.  Without arguments the original
    # hard-coded locations are used.
    cli_args = sys.argv[1:]
    src_root = cli_args[0] if len(cli_args) > 0 else "/home/pic-image/train/12306train"
    dst_root = cli_args[1] if len(cli_args) > 1 else "/home/pic-image/train/12306train3"
    threadOPS(src_root, dst_root)
鏈接: http://www.cnblogs.com/gongxijun/p/6117588.html