keras 模型多輸出 loss weight metrics 設置


 

keras支持模型多輸入多輸出,本文記錄多輸出時loss、loss weight和metrics的設置方式。

<!--more-->

模型輸出

假設模型具有多個輸出

  • classify: 二維數組,分類softmax輸出,需要配置交叉熵損失

  • segmentation:與輸入同尺寸map,sigmoid輸出,需要配置二分類損失

  • others:自定義其他輸出,需要自定義損失

具體配置

model

  • 變量均爲模型中網絡層

inputs = [input_1 , input_2]
outputs = [classify, segmentation, others]
model = keras.models.Model(inputs, outputs)

loss

my_loss = {
    'classify': 'categorical_crossentropy',\
    'segmentation':'binary_crossentropy',\
    'others':my_loss_fun}

loss weight

my_loss_weights = {
    'classify':1,\
    'segmentation':1,\
    'others':10}

metrics

my_metrics ={
    'classify':'acc',\
    'segmentation':[mean_iou,'acc'],\
    'others':['mse','acc']
    }

編譯

model.compile(optimizer=Adam(lr=config.LEARNING_RATE), loss=my_loss, loss_weights= my_loss_weights, metrics= my_metrics)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章