前言
Keras
中有一個圖像數據處理器ImageDataGenerator
,能夠很方便地進行數據增強,並且從文件中批量加載圖片,避免數據集過大時,一下子加載進內存會崩掉。但是從官方文檔發現,並沒有一個比較重要的圖像增強方式:隨機裁剪,本博客就是記錄一下如何在對ImageDataGenerator
中生成的batch做圖像裁剪
國際慣例,參考博客:
Keras 在fit_generator訓練方式中加入圖像random_crop
Extending Keras’ ImageDataGenerator to Support Random Cropping
第二個博客比較全,第三個博客只介紹了分類數據的增強,如果是圖像分割或者超分辨率,輸出仍是一張圖像,所以涉及到對image
和mask
進行同步增強
代碼
先介紹一下數據集目錄結構:
在test
文件夾下,分別有GT
和NGT
兩個文件夾,每個文件夾存儲的都是bmp
圖像文件
其次需要注意,從ImageDataGenerator
中取數據用的是next(generator)
函數
-
載入相關包
from keras_preprocessing.image import ImageDataGenerator import matplotlib.pyplot as plt import numpy as np
-
先使用自帶的
ImageDataGenerator
配合flow_from_director
讀取數據
創建生成器train_img_datagen=ImageDataGenerator()#各種預處理 train_mask_datagen=ImageDataGenerator()#各種預處理
讀取文件
seed=2 #圖像會隨機打亂即shuffle,但是輸入和輸出的打亂順序必須一樣 batch_size=2 target_size=(1080,1920) train_img_gen=train_img_datagen.flow_from_directory('./test',classes=['NGT'], class_mode=None, batch_size=batch_size, target_size=target_size, shuffle=True, seed=seed, interpolation='bicubic') train_mask_gen=train_img_datagen.flow_from_directory('./test', classes=['GT'], class_mode=None, batch_size=batch_size, target_size=target_size, shuffle=True, seed=seed, interpolation='bicubic')
封裝打包
train_generator=zip(train_img_gen,train_mask_gen)
-
定義裁剪器,裁剪圖像和對應的mask:
def crop_generator(batch_gen,crop_size=(270,480)): while True: batch_x,batch_y=next(batch_gen) crops_img=np.zeros((batch_x.shape[0],crop_size[0],crop_size[1],3)) crops_mask=np.zeros((batch_y.shape[0],crop_size[0],crop_size[1],3)) height,width=batch_x.shape[1],batch_x.shape[2] for i in range(batch_x.shape[0]): #裁剪圖像 x=np.random.randint(0,height-crop_size[0]+1) y=np.random.randint(0,width-crop_size[1]+1) crops_img[i]=batch_x[i,x:x+crop_size[0],y:y+crop_size[1]] crops_mask[i]=batch_y[i,x:x+crop_size[0],y:y+crop_size[1]] yield (crops_img,crops_mask)
-
使用裁剪器對
Generator
進行裁剪train_crops=crop_generator(train_generator)
可視化:
img,mask=next(train_crops)
print(img.shape)
plt.subplot(2,1,1)
plt.imshow(img[0]/255)
plt.subplot(2,1,2)
plt.imshow(mask[0]/255)
後記
記住要用while(True)
死循環,並且yield
在while
循環內部,和for
循環外部,代表每個批次
代碼:
鏈接:https://pan.baidu.com/s/1UNZLke5kygBFHJ8iR8wV2A
提取碼:e51e