Tensorflow2.0 保存和加載模型的幾種方法

零、綜述

    1. save/load weights
    1. save/load entire model
    1. saved_model

一、Save the weights

1.一次性保存所有參數

model.save_weights('./checkpoints/my_checkpoint') 

2.加載權重

注意,用該方法保存模型只保存了參數,文件較小,加載較快,但是測試/部署時需要重建搭建網絡。

model = create_model() #定義網絡框架
model.load_weights('./checkpoints/my_checkpoint') #加載訓練好的權重

loss, acc = model.evaluate(test_images, test_labels)

network.save_weights('weights.ckpt') #保存權重
print('saved weights')
del network

network = Sequential([layers.Dense(256)...])#模型必須跟訓練時參數一模一樣
network.compile(optimizer=optimizer.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('weights.ckpt') #加載後可從檢查點處繼續訓練
network.evaluate(ds_val)

二.Save the model

該方法把模型也保存了,文件較大,效率比較低。

#保存模型和參數
network.save('model.h5')
#刪除模型和參數
del network
#重新加載模型和參數
network = tf.keras.models.load_model('model.h5')
network.evaluate(x_val, y_val)

#三、ONNX
保存爲onnx,這是通用格式,python生成的可以用c++解析,一般python訓練而用C++部署。
注意,ONNX可以轉TensorRT,以部署到NVIDIA的嵌入式設備中.


tf.saved_model.save(m, '/tmp/saved_model/') #可以給其餘語言使用的

imported = tf.saved_model.load(path) #直接Load
f = imported.signatures["serving_default"]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章