一文詳細綜述數據增強方法(附代碼)

導讀

在深度學習時代,數據的規模越大、質量越高,模型就能夠擁有更好的泛化能力,數據直接決定了模型學習的上限。然而在實際工程中,採集的數據很難覆蓋全部的場景,比如圖像的光照條件,同一場景拍攝的圖片可能由於光線不同就會有很大的差異性,那麼在訓練模型的時候就需要加入光照方面的數據增強。另一方面,即使擁有大量的數據,也應該進行數據增強,這樣有助於添加相關數據數據集中數據的數量,防止模型學習到不想要的模型,避免出現過擬合現象。

數據增強的具體使用方法有兩種,一種是事先執行所有的轉換,實質是增強數據集的大小,這種方法稱爲線下增強。它比較適用於較小的數據集,最終將增加一定倍數的數據量,這個倍數取決於轉換的圖片個數,比如我需要對所有的圖片進行旋轉,則數據量增加一倍,本文中討論的就是該方法。另一種是在將數據送入到機器學習模型的時候小批量(mini-batch)的轉換,這種方法被稱爲線上增強或者飛行增強。這種方法比較適用於大數據集合,pytorch中的transforms函數就是基於該方法,在訓練中每次對原始圖像進行擾動處理,經過多次幾輪(epoch)訓練之後,就等效於數據增加。

常用的數據增強有兩種,有監督和無監督兩種。本文只探討有監督數據增強。有監督數據增強是基於現有的數據集,通過分析數據的完備性,採用一定的規則對現有數據進行擴充。有監督數據增強可以細分爲單樣本數據增強和多樣本數據增強,在實際工程應用中,單樣本數據增強使用更多,在git上有一些性能較好開源數據增強項目,他們功能較全並且處理速度也很快,開發者可以直接調用,如imgaug和albumentations。在pytorch中,可以通過torchvision的transforms模塊來實現集成了很多數據增強函數包。本文主要介紹單樣本數據增強的一些常用方法。

1.裁剪

做裁剪操作主要是考慮原始圖像的寬高擾動,在大多數圖像分類網絡中,樣本在輸入網絡前必須要統一大小,所以通過調整圖像的尺寸可以大量的擴展數據。通過裁剪有兩種擴種方式,一種是對大尺寸的圖像直接按照需要送入網絡的尺寸進行裁剪,比如原始圖像的分辨率大小是256x256,現在網絡需要輸入的圖像像素尺寸是224x224,這樣可以直接在原始圖像上進行隨機裁剪224x224 像素大小的圖像即可,這樣一張圖可以擴充32x32張圖片;另外一種是將隨機裁剪固定尺寸大小的圖片,然後再將圖像通過插值算法調整到網絡需要的尺寸大小。由於數據集中通常數據大小不一,後者通常使用的較多。

使用opencv進行圖像裁剪,利用隨機數確定圖像的裁剪範圍,代碼如下:

img_path = '../../img/ch3_img1.jpg'
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w, _ = img.shape
new_h1, new_h2 = np.random.randint(0, h-512, 2)
new_w1, new_w2 = np.random.randint(0, w-512, 2)
img_crop1 = img[new_h1:new_h1+512, new_w1:new_w1+512, :]
img_crop2 = img[new_h2:new_h2+512, new_w2:new_w2+512, :]

# 顯示
plt.figure(figsize=(15, 10))
plt.subplot(1,3,1), plt.imshow(img)
plt.axis('off'); plt.title('原圖')
plt.subplot(1,3,2), plt.imshow(img_crop1)
plt.axis('off'); plt.title('水平鏡像')
plt.subplot(1,3,3), plt.imshow(img_crop2)
plt.axis('off'); plt.title('垂直鏡像')
plt.show()

運行上述代碼可到如下結果圖:

image

2.翻轉和旋轉

翻轉和旋轉都是將原始的圖像像素在位置空間上做變換,圖像的翻轉是將原始的圖像進行鏡像操作,鏡像操作在數據增強中會經常被使用,並且起了非常重要的作用,它主要包括水平鏡像翻轉,垂直鏡像翻轉和原點鏡像翻轉,具體在使用中,需要結合數據形式選擇相應翻轉操作,比如數據集是汽車圖像數據,訓練集合測試集都是正常拍攝的圖片,此時只使用水平鏡像操作,如果加入垂直或者原點鏡像翻轉,會對原始圖像產生干擾。

角度旋轉操作和圖像鏡像相對,它主要是沿着畫面的中心進行任意角度的變換,該變換是通過將原圖像和仿射變換矩陣相乘實現的。爲了實現圖像的中心旋轉,除了要知道旋轉角度,還要計算平移的量才能能讓仿射變換的效果等效於旋轉軸的畫面中心。仿射變換矩陣是一個餘弦矩陣,在OpenCV中有實現的庫cv2.getRotationMatrix2D可以使用,該函數的第1個參數是旋轉中心,第2個參數是逆時針旋轉角度,第3個參數是縮放倍數,對於只是旋轉的情況參數值是1,返回的值就是做仿射變換的矩陣。然後通過cv2.warpAffine()將原圖像矩陣乘以旋轉矩陣得到最終的結果。

通過上述的操作旋轉的圖像會有存在黑邊,如果想去除掉圖片的黑邊,需要將原始的圖像做出一些犧牲,對旋轉後的圖像取最大內接矩陣,該矩陣的長寬比和原始圖像相同,如圖2-35所示。要計算內切矩陣的座標Q,需要通過旋轉角度  和原始圖像矩陣的邊長OP得到。

img

最終的計算公式如公式(1)至(3)所示

利用opencv實現的代碼如下:

# 去除黑邊的操作
crop_image = lambda img, x0, y0, w, h: img[y0:y0+h, x0:x0+w] # 定義裁切函數,後續裁切黑邊使用

def rotate_image(img, angle, crop):
"""
angle: 旋轉的角度
crop: 是否需要進行裁剪,布爾向量
"""

w, h = img.shape[:2]
# 旋轉角度的週期是360°
angle %= 360
# 計算仿射變換矩陣
M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
# 得到旋轉後的圖像
img_rotated = cv2.warpAffine(img, M_rotation, (w, h))

# 如果需要去除黑邊
if crop:
# 裁剪角度的等效週期是180°
angle_crop = angle % 180
if angle > 90:
angle_crop = 180 - angle_crop
# 轉化角度爲弧度
theta = angle_crop * np.pi / 180
# 計算高寬比
hw_ratio = float(h) / float(w)
# 計算裁剪邊長係數的分子項
tan_theta = np.tan(theta)
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)

# 計算分母中和高寬比相關的項
r = hw_ratio if h > w else 1 / hw_ratio
# 計算分母項
denominator = r * tan_theta + 1
# 最終的邊長係數
crop_mult = numerator / denominator

# 得到裁剪區域
w_crop = int(crop_mult * w)
h_crop = int(crop_mult * h)
x0 = int((w - w_crop) / 2)
y0 = int((h - h_crop) / 2)
img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
return img_rotated
#水平鏡像
h_flip = cv2.flip(img,1)
#垂直鏡像
v_flip = cv2.flip(img,0)
#水平垂直鏡像
hv_flip = cv2.flip(img,-1)
#90度旋轉
rows, cols, _ = img.shape
M = cv2.getRotationMatrix2D((cols/2, rows/2), 45, 1)
rotation_45 = cv2.warpAffine(img, M, (cols, rows))
#45度旋轉
M = cv2.getRotationMatrix2D((cols/2, rows/2), 135, 2)
rotation_135 = cv2.warpAffine(img, M,(cols, rows))
#去黑邊旋轉45度
image_rotated = rotate_image(img, 45, True)

#顯示
plt.figure(figsize=(15, 10))
plt.subplot(2,3,1), plt.imshow(img)
plt.axis('off'); plt.title('原圖')
plt.subplot(2,3,2), plt.imshow(h_flip)
plt.axis('off'); plt.title('水平鏡像')
plt.subplot(2,3,3), plt.imshow(v_flip)
plt.axis('off'); plt.title('垂直鏡像')
plt.subplot(2,3,4), plt.imshow(hv_flip)
plt.axis('off'); plt.title('水平垂直鏡像')
plt.subplot(2,3,5), plt.imshow(rotation_45)
plt.axis('off'); plt.title('旋轉45度')
plt.subplot(2,3,6), plt.imshow(image_rotated)
plt.axis('off'); plt.title('去黑邊旋轉45度')
plt.show()

上述代碼通過opencv自帶的flip函數實現了翻轉操作,該函數的第二個參數是控制翻轉的方向。通過內切矩陣計算公式可得無黑邊剪切結果。上述代碼的結果圖如圖2-36所示。通過結果可以看出,旋轉

img

3. 縮放

圖像可以向外或向內縮放。向外縮放時,最終圖像尺寸將大於原始圖像尺寸,爲了保持原始圖像的大小,通常需要結合裁剪,從縮放後的圖像中裁剪出和原始圖像大小一樣的圖像。另一種方法是向內縮放,它會縮小圖像大小,縮小到預設的大小。縮放也會帶了一些問題,如縮放後的圖像尺寸和原始圖像尺寸的長寬比差異較大,會出現圖像失幀的現象,如果在實驗中對最終的結果有一定的影響,需要做等比例縮放,對不足的地方進行邊緣填充。以下是縮放的代碼和結果圖像。

img_2 = cv2.resize(img, (int(h * 1.5), int(w * 1.5)))
img_2 = img_2[int((h - 512) / 2) : int((h + 512) / 2), int((w - 512) / 2) : int((w + 512) /2), :]
img_3 = cv2.resize(img, (512, 512))

## 顯示
plt.figure(figsize=(15, 10))
plt.subplot(1,3,1), plt.imshow(img)
plt.axis('off'); plt.title('原圖')
plt.subplot(1,3,2), plt.imshow(img_2)
plt.axis('off'); plt.title('向外縮放')
plt.subplot(1,3,3), plt.imshow(img_3)
plt.axis('off'); plt.title('向內縮放')
plt.show()
img

4.移位

移位只涉及沿X或Y方向(或兩者)移動圖像,如果圖像的背景是單色被背景或者是純的黑色背景,使用該方法可以很有效的增強數據數量,可以通過cv2.warpAffine實現該代碼

mat_shift = np.float32([[1,0,100], [0,1,200]])
img_1 = cv2.warpAffine(img, mat_shift, (h, w))
mat_shift = np.float32([[1, 0, -150], [0, 1, -150]])
img_2 = cv2.warpAffine(img, mat_shift, (h, w))

## 顯示
plt.figure(figsize=(15, 10))
plt.subplot(1,3,1), plt.imshow(img)
plt.axis('off'); plt.title('原圖')
plt.subplot(1,3,2), plt.imshow(img_1)
plt.axis('off'); plt.title('向右下移動')
plt.subplot(1,3,3), plt.imshow(img_2)
plt.axis('off'); plt.title('左上移動')
plt.show()
img

5. 高斯噪聲

當神經網絡試圖學習可能無用的高頻特徵(即圖像中大量出現的模式)時,通常會發生過度擬合。具有零均值的高斯噪聲基本上在所有頻率中具有數據點,從而有效地扭曲高頻特徵。這也意味着較低頻率的組件也會失真,但你的神經網絡可以學會超越它,添加適量的噪音可以增強學習能力。

基於噪聲的數據增強就是在原來的圖片的基礎上,隨機疊加一些噪聲,最常見的做法就是高斯噪聲。更復雜一點的就是在面積大小可選定、位置隨機的矩形區域上丟棄像素產生黑色矩形塊,從而產生一些彩色噪聲,以Coarse Dropout方法爲代表,甚至還可以對圖片上隨機選取一塊區域並擦除圖像信息。以下是代碼和圖像:

img_s1 = gasuss_noise(img, 0, 0.005)
img_s2 = gasuss_noise(img, 0, 0.05)
plt.figure(figsize=(15, 10))
plt.subplot(1,3,1), plt.imshow(img)
plt.axis('off'); plt.title('原圖')
plt.subplot(1,3,2), plt.imshow(img_s1)
plt.axis('off'); plt.title('方差爲0.005')
plt.subplot(1,3,3), plt.imshow(img_s2)
plt.axis('off'); plt.title('方差爲0.05')
plt.show()
img

6.色彩抖動

上面提到的圖像中有一個比較大的難點是背景干擾,在實際工程中爲了消除圖像在不同背景中存在的差異性,通常會做一些色彩抖動操作,擴充數據集合。色彩抖動主要是在圖像的顏色方面做增強,主要調整的是圖像的亮度,飽和度和對比度。工程中不是任何數據集都適用,通常如果不同背景的圖像較多,加入色彩抖動操作會有很好的提升。

def randomColor(image, saturation=0, brightness=0, contrast=0, sharpness=0):
if random.random() < saturation:
random_factor = np.random.randint(0, 31) / 10. # 隨機因子
image = ImageEnhance.Color(image).enhance(random_factor) # 調整圖像的飽和度
if random.random() < brightness:
random_factor = np.random.randint(10, 21) / 10. # 隨機因子
image = ImageEnhance.Brightness(image).enhance(random_factor) # 調整圖像的亮度
if random.random() < contrast:
random_factor = np.random.randint(10, 21) / 10. # 隨機因1子
image = ImageEnhance.Contrast(image).enhance(random_factor) # 調整圖像對比度
if random.random() < sharpness:
random_factor = np.random.randint(0, 31) / 10. # 隨機因子
ImageEnhance.Sharpness(image).enhance(random_factor) # 調整圖像銳度
return image


cj_img = Image.fromarray(img)
sa_img = np.asarray(randomColor(cj_img, saturation=1))
br_img = np.asarray(randomColor(cj_img, brightness=1))
co_img = np.asarray(randomColor(cj_img, contrast=1))
sh_img = np.asarray(randomColor(cj_img, sharpness=1))
rc_img = np.asarray(randomColor(cj_img, saturation=1, \
brightness=1, contrast=1, sharpness=1))
plt.figure(figsize=(15, 10))
plt.subplot(2,3,1), plt.imshow(img)
plt.axis('off'); plt.title('原圖')
plt.subplot(2,3,2), plt.imshow(sa_img)
plt.axis('off'); plt.title('調整飽和度')
plt.subplot(2,3,3), plt.imshow(br_img)
plt.axis('off'); plt.title('調整亮度')
plt.subplot(2,3,4), plt.imshow(co_img)
plt.axis('off'); plt.title('調整對比度')
plt.subplot(2,3,5), plt.imshow(sh_img)
plt.axis('off'); plt.title('調整銳度')
plt.subplot(2,3,6), plt.imshow(rc_img)
plt.axis('off'); plt.title('調整所有項')
plt.show()
img

參考文獻

  • https://www.cnblogs.com/fydeblog/p/10734733.html

  • https://blog.csdn.net/zhu_hongji/article/details/80984341

  • http://docs.opencv.org/master/d4/d13/tutorial_py_filtering.html

  • 叶韻.深度學習與計算機視覺:算法原理、框架應用與代碼實現[M].北京:機械工業出版社,2018:194-195.

  • Zhang H, Cisse M, Dauphin Y N, et al. mixup: Beyond Empirical Risk Minimization[J]. 2017.

  • https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced

  • https://medium.com/nanonets/nanonets-how-to-use-deep-learning-when-you-have-limited-data-f68c0b512cab

精彩內容推薦

歡迎熱愛打比賽的或者熱愛計算機視覺的小夥伴加羣交流,掃碼添加微信,備註加羣即可。

喜歡的話歡迎在看和轉發

本文分享自微信公衆號 - AI成長社(ai-growth)。
如有侵權,請聯繫 [email protected] 刪除。
本文參與“OSC源創計劃”,歡迎正在閱讀的你也加入,一起分享。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章