十分鐘看懂 Keras fit_generator multi-loss中的彎彎繞

在訓練神經網絡的過程中發現Keras fit_generator訓練的效果莫名其妙地比我手寫的網絡好

雖然很惱火於沒有找到原因

還是淪陷在Keras的便捷之下

初探fit_generator有些不知所措,完全摸不着頭緒

稍微弄明白一些之後寫了一段示例代碼

相信足夠粗淺易懂

 

import tensorflow as tf
from keras import backend as K
from keras.layers import Conv2D,Activation,Add,Input,Lambda
from keras.models import Model
from keras.losses import mean_absolute_error, mean_squared_error
import numpy as np
def mae(a, b):
    return mean_absolute_error(a, b)

def mse(a, b):
    return mean_squared_error(a, b)

def error(a, b):
    return a

def generator():
    while True:
        in_1=np.array([1])
        in_2=np.array([5])
        in_3=np.array([10])
        pre_1=np.array([2])
        pre_2=np.array([20])
        pre_3=np.array([40])
        yield [in_1,in_2,in_3],[pre_1,pre_2,pre_3]


input_1 =Input(shape=(1,), name='input_1', dtype='float32')

input_2 =Input(shape=(1,), name='input_2', dtype='float32')

input_3 =Input(shape=(1,), name='input_3', dtype='float32')


output_1 = Add(name='output_1_name')([input_1,input_2])

output_2 = Add(name='output_2_name')([input_2,input_3])

output_3 = Add(name='output_3_name')([input_1,input_3])

model=Model([input_1,input_2,input_3], [output_1,output_2,output_3], name='VVD_Model')

model.summary()

model.compile(optimizer='nadam',
              loss={
                  'output_1_name':mae,
                  'output_2_name':mae,
                  'output_3_name':mae,
              },
              metrics={'output_1_name':[mae,mse],
                       'output_2_name':[mae,mse],
                       'output_3_name':[error]})

model.fit_generator(generator(),
                    steps_per_epoch=5,
                    epochs=1,
                    validation_data=generator(),
                    validation_steps=3 ,
                    callbacks=[])
print('end')

'''
   input_1           input_2         input_3                  
      1                 5               10
      |\               /|               /|
      |    \         /  |             /  |
      |        \   /    |           /    |
      |          / \    |         /      |
      |        /       \|       /        |
      |      /          |  \  /          |
      |    /            |   /  \         |
      |  /              | /        \     |
      |/                |              \ |
      +                 +                +
   output_1         output_2        output_3
      6                 15               11
    pre_1              pre_2            pre_3
      2                 20               40
mae |6-2|             |15-20|          |11-40|
      4                 5                29
mse  16                 25               29*29
  
  
運行結果
1/5 [=====>........................] - ETA: 3s - loss: 38.0000 - output_1_name_loss: 4.0000 - output_2_name_loss: 5.0000 - output_3_name_loss: 29.0000 - output_1_name_mae: 4.0000 - output_1_name_mse: 16.0000 - output_2_name_mae: 5.0000 - output_2_name_mse: 25.0000 - output_3_name_error: 40.0000
5/5 [==============================] - 1s 173ms/step - loss: 38.0000 - output_1_name_loss: 4.0000 - output_2_name_loss: 5.0000 - output_3_name_loss: 29.0000 - output_1_name_mae: 4.0000 - output_1_name_mse: 16.0000 - output_2_name_mae: 5.0000 - output_2_name_mse: 25.0000 - output_3_name_error: 40.0000 - val_loss: 38.0000 - val_output_1_name_loss: 4.0000 - val_output_2_name_loss: 5.0000 - val_output_3_name_loss: 29.0000 - val_output_1_name_mae: 4.0000 - val_output_1_name_mse: 16.0000 - val_output_2_name_mae: 5.0000 - val_output_2_name_mse: 25.0000 - val_output_3_name_error: 40.0000
end  
  
'''
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章