在使用Keras的時候,因爲需要考慮到效率問題,需要修改fit_generator來適應多輸出
# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
optimizer=optimizers.Adam(lr=learning_rate),
loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
loss_weights={"main": 0.5, "auxiliary": 0.5},
metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
train_gen, epochs=num_epochs, verbose=0, shuffle=True
)
看Keras官方文檔:
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either
- a tuple (inputs, targets)
- a tuple (inputs, targets, sample_weights).
Keras設計多輸出(多任務)使用fit_generator的步驟如下:
根據官方文檔,定義一個generator或者一個class繼承Sequence
class Batch_generator(Sequence):
"""
用於產生batch_1, batch_2(記住是numpy.array格式轉換)
"""
y_batch = {'main':batch_1,'auxiliary':batch_2}
return X_batch, y_batch
# or in another way
def batch_generator():
"""
用於產生batch_1, batch_2(記住是numpy.array格式轉換)
"""
yield X_batch, {'main': batch_1,'auxiliary':batch_2}
重要的事情說三遍(親自採坑,搜了一大圈才發現滴):
如果是多輸出(多任務)的時候,這裏的target是字典類型
如果是多輸出(多任務)的時候,這裏的target是字典類型
如果是多輸出(多任務)的時候,這裏的target是字典類型
Reference:
[1] How to use fit_generator with multiple outputs in Keras
[2] keras:怎樣使用 fit_generator 來訓練多個不同類型的輸出