Keras:關於fit_generator中,yeild到底接收的是什麼類型的值?

本文是關於fit_generator函數中的generator屬性中的yield內容。

什麼?yield是什麼?我也不知道yield呀,生成器呀什麼的[img/眼神閃躲],戳我,我跟你簡單噴噴fit_generator的好處,包含官方文檔,可詳細瞭解

fit_generator(generator=myGenerator(x_train, y_train, batch_size=50),
           	  steps_per_epoch=400,
           	  epochs=30,
           	  validation_data=myGenerator(x_valid, y_valid,  batch_size=50),
              validation_steps=50)

本人小白。搞了一下午。不停的出錯,時而是傳進去的數據類型不正確。時而是傳進去了,在執行訓練的時候,維度又出現bug。說實在的。雖然是小白,但是在np.array、list、tensor對象之間進行轉換我還是會的。所以,我的主要問題就是:我不知道我要轉換成什麼類型,然後送給yield
弄了大半天。毫無進展,晚上吃完飯回來,把generator的內容全部推翻。嘗試重新寫!我就不信了。

盲猜

猜測一yield要的imagelabel肯定都是np.array類型(猜的)
首先:yield要的肯定是image和與之對應的label(假設每次取出50個數據。即batch_size=50)
其次:這個label標籤必須已經做過了one-hot

好了,猜完了。。。。。。

那就開始做吧:
本文數據層次組織:

----- train
----------------------class_1
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3

--------------------------------pic_n
----------------------class_2
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3

--------------------------------pic_n
----------------------class_n
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3

--------------------------------pic_n
其他文件夾結構同上

步驟
(1)、獲取所有的label標籤和image [一一對應]

org_train_path = "D:/1/XiongAnDatasets/AID_1/img_all1/train"
org_valid_path = "D:/1/XiongAnDatasets/AID_1/img_all1/valid"
org_test_path = "D:/1/XiongAnDatasets/AID_1/img_all1/test"

# 需要的識別類型
classes = {'Bridge': 0, 'Meadow': 1, 'River': 2, 'Mountain': 3,
           'Beach': 4, 'Farmland': 5, 'Forest': 6}

# region 讀取數據(順序經此已經打亂)
def get_files(orig_picture):
    class_train = []
    label_train = []
    for index, name in enumerate(classes):
        print(index, name)
        class_path = orig_picture + '/' + name
        for pic in os.listdir(class_path):
            class_train.append(class_path + '/' + pic)
            label_train.append(index)
    temp = np.array([class_train, label_train])
    temp = temp.transpose()
    # shuffle the samples
    np.random.shuffle(temp)
    # after transpose, images is in dimension 0 and label in dimension 1
    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    return image_list, label_list
# endregion
def img2array(img_list):
    pre_x = []
    for i in range(len(img_list)):
        img = cv2.imread(img_list[i])
        img_resize = cv2.resize(img, (height, width))
        new_img = cv2.cvtColor(img_resize, cv2.COLOR_BGR2RGB)
        pre_x.append(new_img)  # input一張圖片
    pre_x = np.array(pre_x) / 127.5 - 1.0
    return pre_x

(2)、調用上述代碼並將得到的labelimageone-hotnp.array


from keras.utils import to_categorical

# region 劃分並打亂數據
    x_train, y_train = get_files(org_train_path)
    # list------>np.array--------->one-hot(<-----<------<----看箭頭-----------)
    y_train_one_hot = to_categorical(np.array(y_train))  # 引入頭文件
    # list------>np.array(<--------<-----<------<----<-------看箭頭-----------)
    x_train_new = img2array(x_train)
 # endregion

(3)、按照batch_size進行逐步送入generator

def myGenerator(X_img, Y_label, batch_size=50):
    # 傳入的x_img的類型和Y_label是ndarray類型,
    # Y_label是已經轉one-hot的了
    assert len(X_img) == len(Y_label)
    total_size = len(X_img)
    while 1:
        for i in range(int(total_size / batch_size)):
            yield X_img[i * batch_size:(i + 1) * batch_size], Y_label[i * batch_size:(i + 1) * batch_size]

    return myGenerator

(4)、訓練調用

# 開始訓練網絡模型
history = model.fit_generator(generator=myGenerator(x_train_new,
													y_train_one_hot),
                              steps_per_epoch=400,
                              epochs=30,    
                              validation_data=myGenerator(x_valid_new,
                              							  y_valid_one_hot),
                              validation_steps=50)

(5)成。。。。。成。。。。。成功。。。。了。。。

總結:還是先聲明:本人小白。以上代碼雖然成功了。但是我並沒有做到阻止內存爆炸的行爲。執行時,還是會把全部數據一起加載到內存後,再進行訓練。這篇文章的主題主要還是要去了解yield接收的數據的類型。如果有哪位大神在人羣中多看了俺一眼,關於內存爆炸的問題,還請不吝賜教。感謝!

如有錯誤請指出。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章