TensorFlow:如何保存和載入神經網絡模型?
1.如何保存完成訓練的神經網絡模型?
第一步,在構建的神經網絡的最後添加Saver()。
#添加saver,保存訓練好的神經網絡
saver = tf.train.Saver()
第二步,在執行訓練的結束的最後添加運行S=saver.save()。(路徑最好選擇與.py源文件同一目錄下新建文件夾。)
#保存模型
saver.save(sess, 'net/ conv_net.ckpt')
保存結果如下:
2.如何載入完成訓練的神經網絡模型?
第一步,拷貝.py源文件構建的神經網絡部分全部代碼至新的.py載入文件。同時設置一個tf.train.Saver()。
#添加saver,保存訓練好的神經網絡
saver = tf.train.Saver()
第二步,在默認圖中調用saver.restore()載入函數。
#載入模型
saver.restore(sess,'net/conv_net.ckpt')
第三步,喂數據。
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
print("Test_accuracy="+str(acc))
測試結果如下: