TensorFlow如何保存和載入神經網絡模型?

TensorFlow:如何保存和載入神經網絡模型?

1.如何保存完成訓練的神經網絡模型?

​ 第一步,在構建的神經網絡的最後添加Saver()。

#添加saver,保存訓練好的神經網絡
saver = tf.train.Saver()

​ 第二步,在執行訓練的結束的最後添加運行S=saver.save()。(路徑最好選擇與.py源文件同一目錄下新建文件夾。)


#保存模型
saver.save(sess, 'net/ conv_net.ckpt')

​ 保存結果如下:

在這裏插入圖片描述

2.如何載入完成訓練的神經網絡模型?

​ 第一步,拷貝.py源文件構建的神經網絡部分全部代碼至新的.py載入文件。同時設置一個tf.train.Saver()。

#添加saver,保存訓練好的神經網絡
saver = tf.train.Saver()

​ 第二步,在默認圖中調用saver.restore()載入函數。

#載入模型
saver.restore(sess,'net/conv_net.ckpt')

​ 第三步,喂數據。

 acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
 print("Test_accuracy="+str(acc))

​ 測試結果如下:

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章