Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
    optimizer=optimizers.Adam(lr=learning_rate),
    loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
    loss_weights={"main": 0.5, "auxiliary": 0.5},
    metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
   train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

Keras官方文档
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

  • a tuple (inputs, targets)
  • a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
	"""
	用于产生batch_1, batch_2(记住是numpy.array格式转换)
	"""
	y_batch = {'main':batch_1,'auxiliary':batch_2}
	return  X_batch, y_batch

# or in another way

def batch_generator():
	"""
	用于产生batch_1, batch_2(记住是numpy.array格式转换)
	"""
	yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):
如果是多输出(多任务)的时候,这里的target是字典类型
如果是多输出(多任务)的时候,这里的target是字典类型
如果是多输出(多任务)的时候,这里的target是字典类型

Reference:
[1] How to use fit_generator with multiple outputs in Keras
[2] keras:怎样使用 fit_generator 来训练多个不同类型的输出

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章