BUG
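When training a model with Keras, saving the model at the end of each epoch fails with "AttributeError: 'NoneType' object has no attribute 'update'". The callback calls model.save(), which in turn calls save_model().
The full traceback is below. The root cause is that the model contains an instance of a custom class that Keras does not know how to deep-copy: get_config() ends with copy.deepcopy(config) (the last frame), and that call fails on the custom object.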
  File "train.py", line 88, in <module>
    build_model()
  File "train.py", line 80, in build_model
    CSVLogger(log_path),
  File "/usr/python/lib/python3.5/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/usr/python/lib/python3.5/site-packages/keras/engine/training.py", line 2268, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/usr/python/lib/python3.5/site-packages/keras/callbacks.py", line 77, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/usr/python/lib/python3.5/site-packages/keras/callbacks.py", line 447, in on_epoch_end
    self.model.save(filepath, overwrite=True)
  File "/usr/python/lib/python3.5/site-packages/keras/engine/topology.py", line 2591, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/usr/python/lib/python3.5/site-packages/keras/models.py", line 126, in save_model
    'config': model.get_config()
  File "/usr/python/lib/python3.5/site-packages/keras/engine/topology.py", line 2432, in get_config
    return copy.deepcopy(config)
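One way to find out which object breaks deepcopy() is to try copying each entry on its own. The sketch below is a hypothetical debugging helper (find_uncopyable is not part of Keras) that recurses into dicts, lists, and tuples and reports the innermost entries that fail. The demo config uses a hypothetical Helper class holding a threading.Lock, which cannot be deep-copied. Note that this particular failure is a TypeError: the exception type depends on the offending object, so it need not match the AttributeError above.

```python
import copy
import threading


def find_uncopyable(obj, path='config'):
    """Return (path, exception type name) pairs for the innermost values
    under `obj` that copy.deepcopy() cannot handle."""
    try:
        copy.deepcopy(obj)
        return []  # the whole subtree copies fine
    except Exception as exc:
        error = type(exc).__name__

    # Descend into containers to narrow down the offending entry.
    if isinstance(obj, dict):
        children = [('%s[%r]' % (path, k), v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [('%s[%d]' % (path, i), v) for i, v in enumerate(obj)]
    else:
        children = []

    found = []
    for child_path, child in children:
        found.extend(find_uncopyable(child, child_path))
    # If no child fails on its own, this object itself is the culprit.
    return found or [(path, error)]


class Helper(object):
    """Hypothetical custom class holding a lock, which cannot be deep-copied."""

    def __init__(self):
        self.lock = threading.Lock()


config = {'name': 'my_layer', 'units': 32, 'extra': [1, Helper()]}
print(find_uncopyable(config))
```

Because model.get_config() itself raises in this situation, the helper would have to be run on smaller pieces, for example on each layer's get_config() in model.layers, assuming the custom object ends up in a layer's config.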
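Solutions
Method 1
Pass save_weights_only=True to the ModelCheckpoint callback so that only the weights are written. To load, rebuild the model with the same code that created it, then call load_weights() on it. Saving weights only does not go through get_config(), so the deepcopy() is avoided.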
from keras.callbacks import ModelCheckpoint

callbacks = [ModelCheckpoint(model_path, save_weights_only=True,
                             monitor='val_loss', mode='min', save_best_only=True)]
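Method 2
For each class that cannot be deep-copied, define your own copy method by implementing __deepcopy__(). Python's copy.deepcopy() calls it with a memo dictionary, so the method must accept that argument.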
class YourClass(object):
    # ...
    def __deepcopy__(self, memo):
        # copy.deepcopy() passes a memo dict; Your_deep_copy() stands for your own copying logic
        return Your_deep_copy()
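A runnable sketch of Method 2, assuming the offending object is a hypothetical Helper class whose only non-copyable attribute is a threading.Lock. Its __deepcopy__(self, memo) copies the ordinary attributes with copy.deepcopy(value, memo) and gives the copy a fresh lock. Your own class will need its own rule for whatever attribute cannot be copied.

```python
import copy
import threading


class Helper(object):
    """Hypothetical custom class: copyable settings plus a lock that is not."""

    def __init__(self, scale, names):
        self.scale = scale
        self.names = names
        self.lock = threading.Lock()

    def __deepcopy__(self, memo):
        # Build the copy without calling __init__, then register it in memo
        # so that shared or circular references resolve to the same copy.
        new = self.__class__.__new__(self.__class__)
        memo[id(self)] = new
        for key, value in self.__dict__.items():
            if key == 'lock':
                # A lock cannot be copied; give the copy a fresh one instead.
                setattr(new, key, threading.Lock())
            else:
                setattr(new, key, copy.deepcopy(value, memo))
        return new


original = Helper(2.0, ['a', 'b'])
clone = copy.deepcopy({'helper': original})['helper']
print(clone is original, clone.scale, clone.names, clone.names is original.names)
```

Depending on the Keras version, save_model() may also need the config to be JSON-serializable, so making deepcopy() succeed may not be enough on its own; Method 1 sidesteps both issues.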