保存檢查點(checkpoint)
艾伯特(http://www.aibbt.com/)國內第一家人工智能門戶爲了得到可以用來後續恢復模型以進一步訓練或評估的檢查點文件(checkpoint file),我們實例化一個tf.train.Saver
。
saver = tf.train.Saver()
在訓練循環中,將定期調用saver.save()
方法,向訓練文件夾中寫入包含了當前所有可訓練變量值得檢查點文件。
saver.save(sess, FLAGS.train_dir, global_step=step)
這樣,我們以後就可以使用saver.restore()
方法,重載模型的參數,繼續訓練。
saver.restore(sess, FLAGS.train_dir)
評估模型
每隔一千個訓練步驟,我們的代碼會嘗試使用訓練數據集與測試數據集,對模型進行評估。do_eval
函數會被調用三次,分別使用訓練數據集、驗證數據集合測試數據集。
print 'Training Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.train)
print 'Validation Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.validation)
print 'Test Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.test)
注意,更復雜的使用場景通常是,先隔絕
data_sets.test
測試數據集,只有在大量的超參數優化調整(hyperparameter tuning)之後才進行檢查。但是,由於MNIST問題比較簡單,我們在這裏一次性評估所有的數據。
構建評估圖表(Eval Graph)
在打開默認圖表(Graph)之前,我們應該先調用get_data(train=False)
函數,抓取測試數據集。
test_all_images, test_all_labels = get_data(train=False)
在進入訓練循環之前,我們應該先調用mnist.py
文件中的evaluation
函數,傳入的logits和標籤參數要與loss
函數的一致。這樣做事爲了先構建Eval操作。
eval_correct = mnist.evaluation(logits, labels_placeholder)
evaluation
函數會生成tf.nn.in_top_k
操作,如果在K個最有可能的預測中可以發現真的標籤,那麼這個操作就會將模型輸出標記爲正確。在本文中,我們把K的值設置爲1,也就是隻有在預測是真的標籤時,才判定它是正確的。
eval_correct = tf.nn.in_top_k(logits, labels, 1)
評估圖表的輸出(Eval Output)
之後,我們可以創建一個循環,往其中添加feed_dict
,並在調用sess.run()
函數時傳入eval_correct
操作,目的就是用給定的數據集評估模型。
for step in xrange(steps_per_epoch):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
true_count += sess.run(eval_correct, feed_dict=feed_dict)
true_count
變量會累加所有in_top_k
操作判定爲正確的預測之和。接下來,只需要將正確測試的總數,除以例子總數,就可以得出準確率了。
precision = float(true_count) / float(num_examples)
print ' Num examples: %d Num correct: %d Precision @ 1: %0.02f' % (
num_examples, true_count, precision)
http://www.aibbt.com/a/16370.html