一、簡要說明
在模型保存爲model.ckpt時,生成了以下文件,其中的checkpoint文件、meta文件都能用來讀取變量
二、模型保存
在對話sess中使用tf.train.saver.save(sess,save_path)進行保存
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
network_shape=[1,5,10,1]
learning_rate=0.1
display_step=500
num_steps=1000
x_dot=np.linspace(1,2,300,dtype=np.float32)[:,np.newaxis]
y_dot=2*np.power(x_dot,3)+np.power(x_dot,2)+np.random.normal(0,0.5,x_dot.shape)
X_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[0]],name="input")
Y_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[-1]],name="output")
w={"w1":tf.Variable(tf.random_normal([network_shape[0],network_shape[1]]),name='w1'),
"w2":tf.Variable(tf.random_normal([network_shape[1],network_shape[2]]),name='w2'),
"out":tf.Variable(tf.random_normal([network_shape[2],network_shape[3]]),name='out')}
b={"b1":tf.Variable(tf.random_normal([network_shape[1]]),name='b1'),
"b2": tf.Variable(tf.random_normal([network_shape[2]]),name='b2'),
"out": tf.Variable(tf.random_normal([network_shape[3]]),name='out')}
def network(x):
with tf.name_scope('layer_1'):
layer1=tf.nn.relu(tf.matmul(x,w['w1'])+b['b1'])
with tf.name_scope('layer_2'):
layer2=tf.nn.relu(tf.matmul(layer1,w['w2'])+b['b2'])
with tf.name_scope('out'):
output=tf.matmul(layer2,w['out'])+b['out']
return output
prediction=network(X_p)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(Y_p-prediction), reduction_indices=[1]))
train_step=tf.train.AdamOptimizer(learning_rate).minimize(loss)
saver=tf.train.Saver()
init=tf.global_variables_initializer()
with tf.Session()as sess:
sess.run(init)
Plt=plt.figure().add_subplot(1, 1, 1)
Plt.scatter(x_dot,y_dot)
plt.ion()#使matplotlib的顯示模式轉換爲交互(interactive)模式。即使在腳本中遇到plt.show(),代碼還是會繼續執行
plt.show()
for i in range(1,num_steps+1):
_,Loss=sess.run([train_step,loss], feed_dict={X_p: x_dot, Y_p: y_dot})
if i%display_step ==0 or i ==1:
print("echo : ",i,"loss = ",Loss)
prediction_value=sess.run(prediction,feed_dict={X_p:x_dot})#shape=(300,1)
if i !=1:
Plt.lines.remove(lines[0])#刪去上次畫的圖
# try:
# Plt.lines.remove(lines[0])
# except Exception:
# pass
lines=Plt.plot(x_dot,prediction_value)#
plt.pause(1)# 爲防止matplotlib畫圖過快,畫完圖後自動關閉圖像窗口
saver.save(sess=sess,save_path='./ckpt_files/model.ckpt')
tf.summary.FileWriter('./log',tf.get_default_graph())
# plt.waitforbuttonpress()#使最後一張圖打開狀態,不馬上結束程序運行
三、可視化
在定義命名空間時,使用with tf.name_scope('namescope'):
保存events文件時,使用tf.summary.FileWriter(log_dir,tf.get_default_graph())
獲得瀏覽器地址時,使用tensorboard --logdir XXX
四、讀取保存的ckpt文件
有多種方法可以restore保存的變量的數據:
方法一:使用變量名獲得變量
需要知道模型在訓練的時候是如何定義的,在取出時也定義一個同樣大小類型的變量,restore之後run變量
方法二:使用meta圖文件
可對圖進行操作,restore之後利用<op_name>:<output_index>取出tensor,run這個tensor就能獲得變量
方法三:使用checkpoint文件
可reader這個checkpoint文件,在restore之後通過tensor名稱獲得變量,這個方法可以獲得檢查點中所有的變量名
##########################模型的恢復(一):利用變量名############
import tensorflow as tf
network_shape=[1,5,10,1]
date=tf.Variable(initial_value=tf.random_normal([network_shape[0],network_shape[1]]),dtype=tf.float32,name='w1')
saver=tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess=sess,save_path='./ckpt_files/model.ckpt')
da=sess.run(date)
print(da)
##########################模型的恢復(二):利用meta文件############
# import tensorflow as tf
# saver=tf.train.import_meta_graph(meta_graph_or_file='./ckpt_files/model.ckpt.meta')
# with tf.Session() as sess:
# saver.restore(sess,save_path='./ckpt_files/model.ckpt')
# graph=tf.get_default_graph()
# da=graph.get_tensor_by_name(name='w1:0')# Tensor names must be of the form "<op_name>:<output_index>"
# date=sess.run(da)
# print(date)
##########################模型的恢復(三):利用checkpoint文件############
# from tensorflow.python import pywrap_tensorflow
# checkpoint_path = './ckpt_files/model.ckpt'
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# var_to_shape_map = reader.get_variable_to_shape_map()
# print(reader.get_tensor('w1'))
# print(var_to_shape_map)
# for key in var_to_shape_map:
# print("tensor_name: ", key)
# print(reader.get_tensor(key))
參考鏈接: