本文是關於fit_generator
函數中的generator
屬性中的yield
內容。
什麼?yield是什麼?我也不知道yield呀,生成器呀什麼的[img/眼神閃躲],戳我,我跟你簡單噴噴fit_generator的好處,包含官方文檔,可詳細瞭解
fit_generator(generator=myGenerator(x_train, y_train, batch_size=50),
steps_per_epoch=400,
epochs=30,
validation_data=myGenerator(x_valid, y_valid, batch_size=50),
validation_steps=50)
本人小白。搞了一下午。不停的出錯,時而是傳進去的數據類型不正確。時而是傳進去了,在執行訓練的時候,維度又出現bug。說實在的。雖然是小白,但是在np.array、list、tensor對象之間進行轉換我還是會的。所以,我的主要問題就是:我不知道我要轉換成什麼類型,然後送給yield
弄了大半天。毫無進展,晚上吃完飯回來,把generator的內容全部推翻。嘗試重新寫!我就不信了。
盲猜
猜測一、yield
要的image
和label
肯定都是np.array
類型(猜的)
首先:yield要的肯定是image
和與之對應的label
(假設每次取出50個數據。即batch_size=50
)
其次:這個label
標籤必須已經做過了one-hot
好了,猜完了。。。。。。
那就開始做吧:
本文數據層次組織:
----- train
----------------------class_1
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3
…
--------------------------------pic_n
----------------------class_2
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3
…
--------------------------------pic_n
----------------------class_n
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3
…
--------------------------------pic_n
其他文件夾結構同上
步驟:
(1)、獲取所有的label
標籤和image
[一一對應]
org_train_path = "D:/1/XiongAnDatasets/AID_1/img_all1/train"
org_valid_path = "D:/1/XiongAnDatasets/AID_1/img_all1/valid"
org_test_path = "D:/1/XiongAnDatasets/AID_1/img_all1/test"
# 需要的識別類型
classes = {'Bridge': 0, 'Meadow': 1, 'River': 2, 'Mountain': 3,
'Beach': 4, 'Farmland': 5, 'Forest': 6}
# region 讀取數據(順序經此已經打亂)
def get_files(orig_picture):
class_train = []
label_train = []
for index, name in enumerate(classes):
print(index, name)
class_path = orig_picture + '/' + name
for pic in os.listdir(class_path):
class_train.append(class_path + '/' + pic)
label_train.append(index)
temp = np.array([class_train, label_train])
temp = temp.transpose()
# shuffle the samples
np.random.shuffle(temp)
# after transpose, images is in dimension 0 and label in dimension 1
image_list = list(temp[:, 0])
label_list = list(temp[:, 1])
return image_list, label_list
# endregion
def img2array(img_list):
pre_x = []
for i in range(len(img_list)):
img = cv2.imread(img_list[i])
img_resize = cv2.resize(img, (height, width))
new_img = cv2.cvtColor(img_resize, cv2.COLOR_BGR2RGB)
pre_x.append(new_img) # input一張圖片
pre_x = np.array(pre_x) / 127.5 - 1.0
return pre_x
(2)、調用上述代碼並將得到的label
和image
轉one-hot
和np.array
from keras.utils import to_categorical
# region 劃分並打亂數據
x_train, y_train = get_files(org_train_path)
# list------>np.array--------->one-hot(<-----<------<----看箭頭-----------)
y_train_one_hot = to_categorical(np.array(y_train)) # 引入頭文件
# list------>np.array(<--------<-----<------<----<-------看箭頭-----------)
x_train_new = img2array(x_train)
# endregion
(3)、按照batch_size
進行逐步送入generator
def myGenerator(X_img, Y_label, batch_size=50):
# 傳入的x_img的類型和Y_label是ndarray類型,
# Y_label是已經轉one-hot的了
assert len(X_img) == len(Y_label)
total_size = len(X_img)
while 1:
for i in range(int(total_size / batch_size)):
yield X_img[i * batch_size:(i + 1) * batch_size], Y_label[i * batch_size:(i + 1) * batch_size]
return myGenerator
(4)、訓練調用
# 開始訓練網絡模型
history = model.fit_generator(generator=myGenerator(x_train_new,
y_train_one_hot),
steps_per_epoch=400,
epochs=30,
validation_data=myGenerator(x_valid_new,
y_valid_one_hot),
validation_steps=50)
(5)成。。。。。成。。。。。成功。。。。了。。。
總結:還是先聲明:本人小白。以上代碼雖然成功了。但是我並沒有做到阻止內存爆炸的行爲。執行時,還是會把全部數據一起加載到內存後,再進行訓練。這篇文章的主題主要還是要去了解yield
接收的數據的類型。如果有哪位大神在人羣中多看了俺一眼,關於內存爆炸的問題,還請不吝賜教。感謝!
如有錯誤請指出。