兩種辦法
- Checkpoint
- Protocol_buffer
區別:前者可以繼續訓練,後者只能預測。
Checkpoint_save
# Saver writes the model (graph structure + variable values) to checkpoint files.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Initialize all variables.
    sess.run(init)
    # Train for 11 epochs.
    for epoch in range(11):
        for batch in range(n_batch):
            # Fetch one mini-batch of images and labels.
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one training step on the batch.
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # Evaluate test-set accuracy once per epoch.
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    # Save the trained model as a checkpoint (can be reloaded for further training).
    saver.save(sess, 'models/my_model.ckpt')
Checkpoint_restore1(有結構有參數)
# Load the MNIST dataset (labels one-hot encoded).
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Mini-batch size.
batch_size = 64
# Number of batches per epoch.
n_batch = mnist.train.num_examples // batch_size
with tf.Session() as sess:
    # Restore the graph structure from the .meta file.
    saver = tf.train.import_meta_graph('models/my_model.ckpt.meta')
    # Restore the trained parameter values into that graph.
    saver.restore(sess, 'models/my_model.ckpt')
    # Look up tensors by name. The model output was saved under the name
    # "output"; the ":0" suffix (the op's first output) is added
    # automatically when the model is saved, so it must be written here too.
    output = sess.graph.get_tensor_by_name('output:0')
    # The accuracy computation was saved under the name "accuracy".
    accuracy = sess.graph.get_tensor_by_name('accuracy:0')
    # "train" was saved as an operation, not a tensor, so it needs
    # get_operation_by_name instead of get_tensor_by_name.
    train_step = sess.graph.get_operation_by_name('train')
    # Evaluate test-set accuracy; "x-input"/"y-input" are the placeholder
    # names the model was built with (":0" added automatically, as above).
    print(sess.run(accuracy, feed_dict={'x-input:0': mnist.test.images, 'y-input:0': mnist.test.labels}))
    # Continue training for another 11 epochs on top of the restored model.
    for epoch in range(11):
        for batch in range(n_batch):
            # Fetch one mini-batch of images and labels.
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one training step on the batch.
            sess.run(train_step, feed_dict={'x-input:0': batch_xs, 'y-input:0': batch_ys})
        # Evaluate test-set accuracy once per epoch.
        acc = sess.run(accuracy, feed_dict={'x-input:0': mnist.test.images, 'y-input:0': mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
Checkpoint_restore2(重新定義網絡結構,只載入參數)
# Load the MNIST dataset (labels one-hot encoded).
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Mini-batch size.
batch_size = 64
# Number of batches per epoch.
n_batch = mnist.train.num_examples // batch_size
# Placeholders for images (784 = 28*28 flattened pixels) and one-hot labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Re-define the network: 784 input neurons -> 10 output neurons.
# The variable shapes must match the previously trained model exactly,
# otherwise saver.restore below will fail.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)
# Accuracy = fraction of samples whose predicted class equals the true class.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Saver used to load (and later save) the model.
# max_to_keep=5: keep at most the 5 most recent checkpoints, deleting older ones.
saver = tf.train.Saver(max_to_keep=5)
# Cross-entropy loss.
# NOTE(review): tf.losses.softmax_cross_entropy expects raw logits, but
# `prediction` has already been through softmax. Kept as-is to stay
# compatible with how the original model was trained — worth confirming.
loss = tf.losses.softmax_cross_entropy(y, prediction)
# Adam optimizer with learning rate 0.001.
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
with tf.Session() as sess:
    # Initialize all variables.
    sess.run(tf.global_variables_initializer())
    # Accuracy with freshly initialized (zero) weights — random-guess level.
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
    # Load the previously trained parameter values into the re-defined graph.
    saver.restore(sess, 'models/my_model.ckpt')
    # Accuracy after restoring — should match the saved model's accuracy.
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
    # Continue training for another 11 epochs on top of the restored model.
    for epoch in range(11):
        for batch in range(n_batch):
            # Fetch one mini-batch of images and labels.
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one training step on the batch.
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # Evaluate test-set accuracy once per epoch.
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
        # global_step appends the epoch number to the checkpoint file name,
        # e.g. my_model.ckpt-0, my_model.ckpt-1, ...
        saver.save(sess, 'models/my_model.ckpt', global_step=epoch)
Protocol_buffer_save
with tf.Session() as sess:
    # Initialize all variables.
    sess.run(init)
    # Train for 11 epochs.
    for epoch in range(11):
        for batch in range(n_batch):
            # Fetch one mini-batch of images and labels.
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one training step on the batch.
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # Evaluate test-set accuracy once per epoch.
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    # Freeze the graph: convert variables to constants so structure and
    # parameters can be stored together in a single .pb file.
    # output_node_names lists the nodes that must stay reachable after freezing.
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output', 'accuracy'])
    # Serialize the frozen graph to pb_models/my_model.pb.
    with tf.gfile.FastGFile('pb_models/my_model.pb', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
Protocol_buffer_restore
# Load the MNIST dataset (labels one-hot encoded).
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Load the frozen model file.
with tf.gfile.FastGFile('pb_models/my_model.pb', 'rb') as f:
    # Parse the serialized GraphDef from the .pb file.
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # Import the graph into the current default graph; name='' avoids
    # prefixing every imported node name.
    tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    # Look up tensors by name; the ":0" suffix (the op's first output)
    # was added automatically when the model was saved.
    output = sess.graph.get_tensor_by_name('output:0')
    # The accuracy computation was saved under the name "accuracy".
    accuracy = sess.graph.get_tensor_by_name('accuracy:0')
    # A frozen .pb graph can only be used for inference (variables were
    # converted to constants, so no further training is possible).
    print(sess.run(accuracy, feed_dict={'x-input:0': mnist.test.images, 'y-input:0': mnist.test.labels}))