keras支持模型多輸入多輸出,本文記錄多輸出時loss、loss weight和metrics的設置方式。
<!--more-->
模型輸出
假設模型具有多個輸出
-
classify: 二維數組,分類softmax輸出,需要配置交叉熵損失
-
segmentation:與輸入同尺寸map,sigmoid輸出,需要配置二分類損失
-
others:自定義其他輸出,需要自定義損失
具體配置
model
-
變量均爲模型中網絡層
inputs = [input_1 , input_2] outputs = [classify, segmentation, others] model = keras.models.Model(inputs, outputs)
loss
my_loss = { 'classify': 'categorical_crossentropy',\ 'segmentation':'binary_crossentropy',\ 'others':my_loss_fun}
loss weight
my_loss_weights = { 'classify':1,\ 'segmentation':1,\ 'others':10}
metrics
my_metrics ={ 'classify':'acc',\ 'segmentation':[mean_iou,'acc'],\ 'others':['mse','acc'] }
編譯
model.compile(optimizer=Adam(lr=config.LEARNING_RATE), loss=my_loss, loss_weights= my_loss_weights, metrics= my_metrics)