Keras多輸出(多任務)如何設置fit_generator

在使用Keras的時候,因爲需要考慮到效率問題,需要修改fit_generator來適應多輸出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
    optimizer=optimizers.Adam(lr=learning_rate),
    loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
    loss_weights={"main": 0.5, "auxiliary": 0.5},
    metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
   train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

Keras官方文檔
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

  • a tuple (inputs, targets)
  • a tuple (inputs, targets, sample_weights).

Keras設計多輸出(多任務)使用fit_generator的步驟如下:

根據官方文檔,定義一個generator或者一個class繼承Sequence

class Batch_generator(Sequence):
	"""
	用於產生batch_1, batch_2(記住是numpy.array格式轉換)
	"""
	y_batch = {'main':batch_1,'auxiliary':batch_2}
	return  X_batch, y_batch

# or in another way

def batch_generator():
	"""
	用於產生batch_1, batch_2(記住是numpy.array格式轉換)
	"""
	yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情說三遍(親自採坑,搜了一大圈才發現滴):
如果是多輸出(多任務)的時候,這裏的target是字典類型
如果是多輸出(多任務)的時候,這裏的target是字典類型
如果是多輸出(多任務)的時候,這裏的target是字典類型

Reference:
[1] How to use fit_generator with multiple outputs in Keras
[2] keras:怎樣使用 fit_generator 來訓練多個不同類型的輸出

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章