tensorflow(四)實戰——基於全連接網絡的模型保存,讀取,tensorboard可視化

一、簡要說明

在模型保存爲model.ckpt時,生成了以下文件,其中的checkpoint文件、meta文件都能用來讀取變量

二、模型保存

在對話sess中使用tf.train.saver.save(sess,save_path)進行保存

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

network_shape=[1,5,10,1]
learning_rate=0.1
display_step=500
num_steps=1000

x_dot=np.linspace(1,2,300,dtype=np.float32)[:,np.newaxis]
y_dot=2*np.power(x_dot,3)+np.power(x_dot,2)+np.random.normal(0,0.5,x_dot.shape)

X_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[0]],name="input")
Y_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[-1]],name="output")


w={"w1":tf.Variable(tf.random_normal([network_shape[0],network_shape[1]]),name='w1'),
   "w2":tf.Variable(tf.random_normal([network_shape[1],network_shape[2]]),name='w2'),
   "out":tf.Variable(tf.random_normal([network_shape[2],network_shape[3]]),name='out')}

b={"b1":tf.Variable(tf.random_normal([network_shape[1]]),name='b1'),
   "b2": tf.Variable(tf.random_normal([network_shape[2]]),name='b2'),
   "out": tf.Variable(tf.random_normal([network_shape[3]]),name='out')}


def network(x):
    with tf.name_scope('layer_1'):
        layer1=tf.nn.relu(tf.matmul(x,w['w1'])+b['b1'])
    with tf.name_scope('layer_2'):
        layer2=tf.nn.relu(tf.matmul(layer1,w['w2'])+b['b2'])
    with tf.name_scope('out'):
        output=tf.matmul(layer2,w['out'])+b['out']
    return output
prediction=network(X_p)

loss = tf.reduce_mean(tf.reduce_sum(tf.square(Y_p-prediction), reduction_indices=[1]))

train_step=tf.train.AdamOptimizer(learning_rate).minimize(loss)
saver=tf.train.Saver()
init=tf.global_variables_initializer()
with tf.Session()as sess:
    sess.run(init)
    Plt=plt.figure().add_subplot(1, 1, 1)
    Plt.scatter(x_dot,y_dot)
    plt.ion()#使matplotlib的顯示模式轉換爲交互(interactive)模式。即使在腳本中遇到plt.show(),代碼還是會繼續執行
    plt.show()
    for i in range(1,num_steps+1):
        _,Loss=sess.run([train_step,loss], feed_dict={X_p: x_dot, Y_p: y_dot})
        if i%display_step ==0 or i ==1:
            print("echo : ",i,"loss = ",Loss)
            prediction_value=sess.run(prediction,feed_dict={X_p:x_dot})#shape=(300,1)
            if i !=1:
                Plt.lines.remove(lines[0])#刪去上次畫的圖
            # try:
            #     Plt.lines.remove(lines[0])
            # except Exception:
            #     pass
            lines=Plt.plot(x_dot,prediction_value)#
            plt.pause(1)# 爲防止matplotlib畫圖過快,畫完圖後自動關閉圖像窗口
    saver.save(sess=sess,save_path='./ckpt_files/model.ckpt')
    tf.summary.FileWriter('./log',tf.get_default_graph())

    # plt.waitforbuttonpress()#使最後一張圖打開狀態,不馬上結束程序運行

三、可視化

在定義命名空間時,使用with tf.name_scope('namescope'): 

保存events文件時,使用tf.summary.FileWriter(log_dir,tf.get_default_graph())

獲得瀏覽器地址時,使用tensorboard --logdir XXX

四、讀取保存的ckpt文件

有多種方法可以restore保存的變量的數據:

方法一:使用變量名獲得變量

 需要知道模型在訓練的時候是如何定義的,在取出時也定義一個同樣大小類型的變量,restore之後run變量

方法二:使用meta圖文件

可對圖進行操作,restore之後利用<op_name>:<output_index>取出tensor,run這個tensor就能獲得變量

方法三:使用checkpoint文件

可reader這個checkpoint文件,在restore之後通過tensor名稱獲得變量,這個方法可以獲得檢查點中所有的變量名

##########################模型的恢復(一):利用變量名############
import tensorflow as tf
network_shape=[1,5,10,1]
date=tf.Variable(initial_value=tf.random_normal([network_shape[0],network_shape[1]]),dtype=tf.float32,name='w1')
saver=tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess=sess,save_path='./ckpt_files/model.ckpt')
    da=sess.run(date)
    print(da)


##########################模型的恢復(二):利用meta文件############
# import tensorflow as tf
# saver=tf.train.import_meta_graph(meta_graph_or_file='./ckpt_files/model.ckpt.meta')
# with tf.Session() as sess:
#     saver.restore(sess,save_path='./ckpt_files/model.ckpt')
#     graph=tf.get_default_graph()
#     da=graph.get_tensor_by_name(name='w1:0')# Tensor names must be of the form "<op_name>:<output_index>"
#     date=sess.run(da)
#     print(date)

##########################模型的恢復(三):利用checkpoint文件############
# from tensorflow.python import pywrap_tensorflow
# checkpoint_path = './ckpt_files/model.ckpt'
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# var_to_shape_map = reader.get_variable_to_shape_map()
# print(reader.get_tensor('w1'))
# print(var_to_shape_map)
# for key in var_to_shape_map:
#     print("tensor_name: ", key)
#     print(reader.get_tensor(key))

 

參考鏈接:

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章