存儲模型train.py: model.save('model_weight.h5')
在predict.py中,使用model = load_model("model_weight.h5")對模型進行加載的時報錯信息:
- Unknown Layer: LayerName。此處的LayerName代指自定義的layer。
- global name 'tf' is not defined
正確加載方式:
- 聲明自定義的類,並創建實例。
- model = load_model("model_weight.h5", custom_objects={'tf': tf, 'Self_Attention': Self_Attention_shili, "local_Attention":local_Attention_shili}) ;將自己定義的類的名稱和實例傳進去。
- 如果自定義的類中,存在參數沒有設置初始默認值,則會報錯TypeError: init() missing 1 required positional argument: 'XXX'。解決方法:給一個初始值,需要和訓練時候的參數維度一致。
參考鏈接: