免責聲明:本文僅代表個人觀點,如有錯誤,請讀者自己鑑別;如果本文不小心含有別人的原創內容,請聯繫我刪除;本人心血製作,若轉載請註明出處
1、保存模型爲.ckpt文件
saver = tf.train.Saver()
save_path = saver.save(sess, model_path)# 保存模型 其中model_path爲模型保存的文件
保存後的模型有四個文件 checkpoint、SFCN.ckpt.data-00000-of-00001、SFCN.ckpt.index、SFCN.ckpt.meta
2、保存event, event爲事件的保存路徑
summary_writer = tf.summary.FileWriter(event, graph=sess.graph)
在命令窗口中輸入 tensorboard --logdir=事件路徑
在瀏覽器中可以可視化tensorboard,查看圖中有哪些節點
3、將ckpt文件轉化爲.pb文件,注意output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3" 這一行要寫入所有需要保留的節點,注意節點之間不可以加空格
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint, output_graph,
                 output_node_names="Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3"):
    """Freeze a TF1 checkpoint into a single self-contained .pb file.

    :param input_checkpoint: checkpoint path prefix (".meta" is appended to it)
    :param output_graph: path where the frozen .pb model is written
    :param output_node_names: comma-separated names of graph nodes to keep;
        no spaces are allowed between the names. Defaults to the nodes of the
        original SFCN model, so existing callers are unaffected.
    :return: None
    """
    # The node names must exist in the original graph; they can be looked up
    # in TensorBoard.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()           # default graph restored from the .meta file
    input_graph_def = graph.as_graph_def()   # serialized GraphDef of that graph
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the variable values
        # Bake the variable values into constants so the graph file is
        # self-contained ("model persistence").
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,    # same as sess.graph_def
            output_node_names=output_node_names.split(","))  # multiple nodes split on ","
        with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
            f.write(output_graph_def.SerializeToString())  # serialize to bytes
        print("%d ops in the final graph." % len(output_graph_def.node))  # op count of final graph


input_checkpoint = "./checkpoint/SFCN.ckpt"  # location of the input ckpt files
output_graph = "node.pb"                     # output .pb file name
freeze_graph(input_checkpoint, output_graph)
4、調用.pb文件,注意y_in = sess.graph.get_tensor_by_name("Placeholder:0") images_in = sess.graph.get_tensor_by_name("Placeholder_2:0") keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0") logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")
這四個節點要與第3步中的節點保持一致,但是要在後面加入“:0”,如第3步中節點爲“Placeholder”,第4步要寫爲“Placeholder:0”
# Load the frozen .pb graph and run inference on every small image,
# timing the forward passes.
global graph
graph = tf.get_default_graph()
with graph.as_default():
    # Parse the serialized GraphDef from disk and merge it into this graph.
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Tensor names are the node names used when freezing, suffixed ":0".
        y_in = sess.graph.get_tensor_by_name("Placeholder:0")
        images_in = sess.graph.get_tensor_by_name("Placeholder_2:0")
        keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0")
        logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")

        batch_size = 1
        testTime = 0
        # Materialize a zero-filled array shaped like the labels to hold the
        # per-image predictions.
        predictLabel = tf.zeros(trainLabel.shape)
        predictLabel = sess.run(predictLabel)

        for i in range(smallImage):
            realbatch_array, real_labels, real_index = getNext_batch(
                trainData, trainLabel, trainIndex, i)
            testStart = time.time()
            # Dropout is disabled at inference time (keep probability 1.0).
            yy = sess.run(
                logits_out,
                feed_dict={images_in: realbatch_array,
                           y_in: real_labels,
                           keep_probability_in: 1.0})
            predictLabel[i, ...] = yy
            testEnd = time.time()
            testTime1 = testEnd - testStart
            testTime += testTime1  # accumulate total inference time