在訓練神經網絡的過程中發現Keras fit_generator訓練的效果莫名其妙地比我手寫的網絡好
雖然很惱火於沒有找到原因
還是淪陷在Keras的便捷之下
初探fit_generator有些不知所措，完全摸不著頭緒
稍微弄明白一些之後寫了一段示例代碼
相信足夠粗淺易懂
import tensorflow as tf
from keras import backend as K
from keras.layers import Conv2D,Activation,Add,Input,Lambda
from keras.models import Model
from keras.losses import mean_absolute_error, mean_squared_error
import numpy as np
def mae(a, b):
    """Mean-absolute-error loss/metric: thin wrapper over the Keras built-in."""
    result = mean_absolute_error(a, b)
    return result
def mse(a, b):
    """Mean-squared-error loss/metric: thin wrapper over the Keras built-in."""
    result = mean_squared_error(a, b)
    return result
def error(a, b):
    """Demo 'metric' that simply reports its first argument unchanged.

    The second argument is deliberately ignored; this exists to show what
    Keras displays when a metric returns one of its inputs as-is.
    """
    return a
def generator():
    """Endlessly yield one fixed training batch: ([in_1, in_2, in_3], [pre_1, pre_2, pre_3]).

    Inputs are the scalars 1, 5, 10 and the targets are 2, 20, 40, each
    wrapped in a length-1 numpy array so Keras sees a batch of size 1.
    """
    while True:
        batch_inputs = [np.array([v]) for v in (1, 5, 10)]
        batch_targets = [np.array([v]) for v in (2, 20, 40)]
        yield batch_inputs, batch_targets
# Three scalar (shape (1,)) float32 inputs, identified by name so the
# loss/metric dicts below can refer to each output individually.
input_1 =Input(shape=(1,), name='input_1', dtype='float32')
input_2 =Input(shape=(1,), name='input_2', dtype='float32')
input_3 =Input(shape=(1,), name='input_3', dtype='float32')
# Each output is the element-wise sum of a pair of inputs; the graph
# contains only Add layers, so the model has no trainable weights and the
# reported losses/metrics stay constant across steps.
output_1 = Add(name='output_1_name')([input_1,input_2])
output_2 = Add(name='output_2_name')([input_2,input_3])
output_3 = Add(name='output_3_name')([input_1,input_3])
model=Model([input_1,input_2,input_3], [output_1,output_2,output_3], name='VVD_Model')
model.summary()
# Per-output losses and metrics are keyed by the Add layers' names above.
# The total loss is the sum of the three per-output losses (4 + 5 + 29 = 38
# in the captured log below, matching the default loss weights of 1 each).
model.compile(optimizer='nadam',
loss={
'output_1_name':mae,
'output_2_name':mae,
'output_3_name':mae,
},
metrics={'output_1_name':[mae,mse],
'output_2_name':[mae,mse],
'output_3_name':[error]})
# NOTE(review): fit_generator is deprecated in TF2-era Keras in favor of
# model.fit, which accepts generators directly — confirm the installed
# Keras version before modernizing.
model.fit_generator(generator(),
steps_per_epoch=5,
epochs=1,
validation_data=generator(),
validation_steps=3 ,
callbacks=[])
print('end')
# Inert module-level string kept verbatim: ASCII wiring diagram of the three
# Add outputs, the expected per-output mae/mse values, and the captured
# training log from one run.
'''
input_1 input_2 input_3
1 5 10
|\ /| /|
| \ / | / |
| \ / | / |
| / \ | / |
| / \| / |
| / | \ / |
| / | / \ |
| / | / \ |
|/ | \ |
+ + +
output_1 output_2 output_3
6 15 11
pre_1 pre_2 pre_3
2 20 40
mae |6-2| |15-20| |11-40|
4 5 29
mse 16 25 29*29
運行結果
1/5 [=====>........................] - ETA: 3s - loss: 38.0000 - output_1_name_loss: 4.0000 - output_2_name_loss: 5.0000 - output_3_name_loss: 29.0000 - output_1_name_mae: 4.0000 - output_1_name_mse: 16.0000 - output_2_name_mae: 5.0000 - output_2_name_mse: 25.0000 - output_3_name_error: 40.0000
5/5 [==============================] - 1s 173ms/step - loss: 38.0000 - output_1_name_loss: 4.0000 - output_2_name_loss: 5.0000 - output_3_name_loss: 29.0000 - output_1_name_mae: 4.0000 - output_1_name_mse: 16.0000 - output_2_name_mae: 5.0000 - output_2_name_mse: 25.0000 - output_3_name_error: 40.0000 - val_loss: 38.0000 - val_output_1_name_loss: 4.0000 - val_output_2_name_loss: 5.0000 - val_output_3_name_loss: 29.0000 - val_output_1_name_mae: 4.0000 - val_output_1_name_mse: 16.0000 - val_output_2_name_mae: 5.0000 - val_output_2_name_mse: 25.0000 - val_output_3_name_error: 40.0000
end
'''