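The project has been really hectic lately, and I have been on a 996 schedule (9 a.m. to 9 p.m., six days a week) for days. Now that one small milestone is done, I finally have time to sort my notes out, and along the way I have come to understand TensorFlow better.
Back in my university lab I rarely touched TensorFlow's C++ interface. Training and inference were both done in Python, which is comparatively simple. C++ takes many more steps than Python, and the first problem I ran into was: how do I use a ckpt model trained in Python?
Below I walk through the steps.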
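1. Verify that the trained model is correct
In other words, first run inference on the trained model in Python to confirm it works. Then give the model input placeholders. Why?
Because the saved ckpt contains only the network parameters you trained and leaves no slot for the input: your tensors have no way into the graph. Defining a placeholder gives the input a place to enter, so it can actually be fed through the network.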
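```python
import tensorflow as tf  # TF 1.x API; in TF 2.x use tf.compat.v1.placeholder in graph mode

# height, width, channel: the input size of your own model
input_img = tf.placeholder(tf.float32, [None, height, width, channel],
                           name="input_img")
```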
Inputs are usually defined like this. To allow any batch size, the first dimension of the tensor is set to None, which should be easy enough to understand.
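To make the role of `None` concrete, here is a small stdlib-only sketch of the rule a `[None, height, width, channel]` placeholder applies when you feed it: every fixed dimension must match exactly, while the batch dimension may be anything. The helper `batch_shape_matches` and the 224x224x3 size are hypothetical illustrations, not part of TensorFlow.

```python
def batch_shape_matches(spec, shape):
    """Return True if a concrete shape fits a placeholder-style spec.

    spec uses None for a free dimension, e.g. [None, 224, 224, 3]
    (hypothetical sizes standing in for [None, height, width, channel]).
    """
    if len(spec) != len(shape):
        return False  # rank must match
    return all(s is None or s == d for s, d in zip(spec, shape))


SPEC = [None, 224, 224, 3]
print(batch_shape_matches(SPEC, [1, 224, 224, 3]))   # True: batch of 1
print(batch_shape_matches(SPEC, [16, 224, 224, 3]))  # True: batch of 16
print(batch_shape_matches(SPEC, [16, 224, 224, 1]))  # False: channel differs
```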
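Next, define your inference op. My own inference is fairly involved, but for many networks a single sess.run is enough. After inference, export the graph and the variables so that the TensorFlow C++ API can use them later (OpenCV can apparently load the result too, but I have not looked into that). Note that the pb written here holds only the graph structure; the trained variable values go into the ckpt, and the two are merged in step 3.
Use the graph_io interface to save the graph as a pb, and re-save the ckpt as well.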
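```python
import os

from tensorflow.python.framework import graph_io

# R_sess: the session used for the inference above; saver: a tf.train.Saver
# created for that graph.
# Write the graph structure (without the trained variable values) as a binary pb.
graph_io.write_graph(R_sess.graph, './model_pb', "name.pb", as_text=False)
# Re-save the variable values as a checkpoint next to it.
saver.save(R_sess, os.path.join('./model_pb', 'name.ckpt'))
```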
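A quick summary here: during training we usually keep the better models, i.e. we set up checkpoints. The checkpoint folder then contains four kinds of files: `checkpoint`, the data file (e.g. `name.ckpt.data-00000-of-00001`), `name.ckpt.meta` and `name.ckpt.index`. The `checkpoint` file is a small text file that records the prefix of the latest checkpoint (and of the other retained ones); the data file holds the actual variable values; the `.meta` file stores the graph structure, in other words your network architecture; and the `.index` file maps each variable name to where its value is stored in the data file(s), so you rarely need to deal with it directly.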
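Because `--input_checkpoint` in step 3 expects the checkpoint prefix rather than any one of these files, it helps to see how the names relate. The sketch below uses hypothetical helpers (stdlib only) and assumes an unsharded V2 checkpoint: it lists the files behind a prefix and reads the latest prefix out of the `checkpoint` text file, which is roughly what `tf.train.latest_checkpoint()` does for you.

```python
import re


def checkpoint_files(prefix, num_shards=1):
    """List the files behind a V2 checkpoint prefix such as 'model_pb/name.ckpt'."""
    data = ["%s.data-%05d-of-%05d" % (prefix, i, num_shards)
            for i in range(num_shards)]
    return [prefix + ".index", prefix + ".meta"] + data


def latest_from_state(text):
    """Extract model_checkpoint_path from the contents of a 'checkpoint' file."""
    # ^ with MULTILINE so that 'all_model_checkpoint_paths' lines do not match.
    m = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', text, re.MULTILINE)
    return m.group(1) if m else None


state = ('model_checkpoint_path: "name.ckpt"\n'
         'all_model_checkpoint_paths: "name.ckpt"\n')
prefix = latest_from_state(state)
print(prefix)                                  # name.ckpt
print(checkpoint_files("model_pb/" + prefix))
```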
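2. Identify your input and output nodes
The function below loads a pb and writes its graph for TensorBoard, so you can inspect the structure and confirm the network's input and output nodes. Only once you know these node names can you freeze the parameters and optimize the graph in the following steps.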
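```python
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary


def import_to_tensorboard(model_dir, log_dir, frozen_graph):
    """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.

    Args:
      model_dir: The location of the protobuf (`pb`) model to visualize
      log_dir: The location for the Tensorboard log to begin visualization from.
      frozen_graph: True if the file is a binary `.pb` (a frozen graph, or the
        one written with `as_text=False`), False if it is a text `.pbtxt`.

    Usage:
      Call this function with your model location and desired log directory.
      Launch Tensorboard by pointing it to the log directory.
      View your imported `.pb` model as a graph.
    """
    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_dir, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            data = f.read()
            if frozen_graph:
                graph_def.ParseFromString(data)  # binary protobuf
            else:
                text_format.Merge(data.decode("utf-8"), graph_def)  # text protobuf
            # Without name="", TensorBoard shows every node under an "import/"
            # prefix; drop that prefix when passing node names to freeze_graph.
            importer.import_graph_def(graph_def)
        pb_visual_writer = summary.FileWriter(log_dir)
        pb_visual_writer.add_graph(sess.graph)
        print("Model Imported. Visualize by running: "
              "> tensorboard --logdir={}".format(log_dir))
```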
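Besides eyeballing the graph in TensorBoard, you can shortlist candidate nodes programmatically. Below is a minimal heuristic sketch; the function name and the skip lists are assumptions, not a TensorFlow API. Inputs are taken to be `Placeholder` ops, and outputs are nodes that no other node consumes, ignoring init ops and the saver's default `save/` scope. Treat the result as a shortlist to confirm in TensorBoard.

```python
def candidate_io_nodes(nodes, skip_ops=("NoOp", "Assign", "SaveV2", "RestoreV2")):
    """Heuristically guess input and output node names of a graph.

    nodes: list of (name, op_type, input_names) triples, e.g. built from a GraphDef:
        [(n.name, n.op, list(n.input)) for n in graph_def.node]
    """
    consumed = set()
    for _, _, node_inputs in nodes:
        for inp in node_inputs:
            # GraphDef input strings look like "conv1/Relu", "split:1" or "^init".
            consumed.add(inp.lstrip("^").split(":")[0])
    inputs = [name for name, op, _ in nodes if op == "Placeholder"]
    outputs = [name for name, op, _ in nodes
               if name not in consumed and op not in skip_ops
               and not name.startswith("save/")]  # tf.train.Saver's default scope
    return inputs, outputs


toy = [
    ("input_img", "Placeholder", []),
    ("w", "VariableV2", []),
    ("w/read", "Identity", ["w"]),
    ("w/initial_value", "Const", []),
    ("w/Assign", "Assign", ["w", "w/initial_value"]),
    ("conv", "Conv2D", ["input_img", "w/read"]),
    ("softmax", "Softmax", ["conv"]),
    ("init", "NoOp", ["^w/Assign"]),
]
print(candidate_io_nodes(toy))  # (['input_img'], ['softmax'])
```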
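3. Freeze the graph
This step simply uses the ready-made freeze_graph.py script that ships with TensorFlow to freeze the parameters. It lives in the tensorflow/python/tools folder of the TensorFlow source tree.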
```python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import re
import sys

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib


def _has_no_variables(sess):
  """Determines if the graph has any variables.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    Location of the output_graph_def.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_parition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_parition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.")
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.")
          return 0
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def


def _parse_input_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into GraphDef proto."""
  if not gfile.Exists(input_graph):
    raise IOError("Input graph file '" + input_graph + "' does not exist!")
  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  return input_graph_def


def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into MetaGraphDef proto."""
  if not gfile.Exists(input_graph):
    raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  print("Loaded meta graph file '" + input_graph)
  return input_meta_graph_def


def _parse_input_saver_proto(input_saver, input_binary):
  """Parses input tensorflow Saver into SaverDef proto."""
  if not gfile.Exists(input_saver):
    raise IOError("Input saver file '" + input_saver + "' does not exist!")
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_saver, mode) as f:
    saver_def = saver_pb2.SaverDef()
    if input_binary:
      saver_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), saver_def)
  return saver_def


def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A `GraphDef` file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted),
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).

  Returns:
    String that is the location of frozen GraphDef.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tags.replace(" ", "").split(","),
      checkpoint_version=checkpoint_version)


def main(unused_args, flags):
  if flags.checkpoint_version == 1:
    checkpoint_version = saver_pb2.SaverDef.V1
  elif flags.checkpoint_version == 2:
    checkpoint_version = saver_pb2.SaverDef.V2
  else:
    raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
                     flags.checkpoint_version)
  freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
               flags.input_checkpoint, flags.output_node_names,
               flags.restore_op_name, flags.filename_tensor_name,
               flags.output_graph, flags.clear_devices, flags.initializer_nodes,
               flags.variable_names_whitelist, flags.variable_names_blacklist,
               flags.input_meta_graph, flags.input_saved_model_dir,
               flags.saved_model_tags, checkpoint_version)


def run_main():
  """Main function of freeze_graph."""
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--input_graph",
      type=str,
      default="",
      help="TensorFlow \'GraphDef\' file to load.")
  parser.add_argument(
      "--input_saver",
      type=str,
      default="",
      help="TensorFlow saver file to load.")
  parser.add_argument(
      "--input_checkpoint",
      type=str,
      default="",
      help="TensorFlow variables file to load.")
  parser.add_argument(
      "--checkpoint_version",
      type=int,
      default=2,
      help="Tensorflow variable file format")
  parser.add_argument(
      "--output_graph",
      type=str,
      default="",
      help="Output \'GraphDef\' file name.")
  parser.add_argument(
      "--input_binary",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="Whether the input files are in binary format.")
  parser.add_argument(
      "--output_node_names",
      type=str,
      default="",
      help="The name of the output nodes, comma separated.")
  parser.add_argument(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  parser.add_argument(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  parser.add_argument(
      "--clear_devices",
      nargs="?",
      const=True,
      type="bool",
      default=True,
      help="Whether to remove device specifications.")
  parser.add_argument(
      "--initializer_nodes",
      type=str,
      default="",
      help="Comma separated list of initializer nodes to run before freezing.")
  parser.add_argument(
      "--variable_names_whitelist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  parser.add_argument(
      "--variable_names_blacklist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  parser.add_argument(
      "--input_meta_graph",
      type=str,
      default="",
      help="TensorFlow \'MetaGraphDef\' file to load.")
  parser.add_argument(
      "--input_saved_model_dir",
      type=str,
      default="",
      help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
  parser.add_argument(
      "--saved_model_tags",
      type=str,
      default="serve",
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)
  flags, unparsed = parser.parse_known_args()

  my_main = lambda unused_args: main(unused_args, flags)
  app.run(main=my_main, argv=[sys.argv[0]] + unparsed)


if __name__ == "__main__":
  run_main()
```
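What freezing does can be pictured on a toy graph. The sketch below uses a hypothetical dict format instead of a real GraphDef and mirrors the two ideas behind `graph_util.convert_variables_to_constants`: keep only the nodes the outputs depend on, then replace each variable with a `Const` node holding its checkpoint value. It only handles reference variables (`Variable`/`VariableV2`); the real function also deals with resource variables and more.

```python
VARIABLE_OPS = ("Variable", "VariableV2")  # reference-variable op types


def freeze_toy_graph(nodes, output_names, checkpoint_values):
    """Toy freeze: keep what the outputs need, then turn variables into consts.

    nodes: {name: {"op": op_type, "inputs": [input names]}}  (hypothetical format)
    checkpoint_values: {variable name: value}, i.e. what a ckpt stores.
    """
    # Step 1: walk backwards from the outputs (cf. graph_util.extract_sub_graph).
    keep, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        # Input names may look like "split:1" or "^init" (control dependency).
        stack.extend(i.lstrip("^").split(":")[0] for i in nodes[name]["inputs"])

    # Step 2: every surviving variable becomes a Const holding its ckpt value.
    frozen = {}
    for name in keep:
        node = nodes[name]
        if node["op"] in VARIABLE_OPS:
            frozen[name] = {"op": "Const", "inputs": [],
                            "value": checkpoint_values[name]}
        else:
            frozen[name] = {"op": node["op"], "inputs": list(node["inputs"])}
    return frozen


graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "w": {"op": "VariableV2", "inputs": []},
    "w/read": {"op": "Identity", "inputs": ["w"]},
    "y": {"op": "MatMul", "inputs": ["x", "w/read"]},
    "save/SaveV2": {"op": "SaveV2", "inputs": ["w"]},  # training-only bookkeeping
}
frozen = freeze_toy_graph(graph, ["y"], {"w": [[1.0]]})
print(sorted(frozen))  # ['w', 'w/read', 'x', 'y']: the saver node is gone
```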
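The script's freeze_graph() takes a lot of parameters:
- input_graph: the graph file of the model to load, either a binary .pb or a text .pbtxt (a .meta file goes to input_meta_graph instead)
- input_saver: the TensorFlow saver file to load; in most of the examples I have seen it can be omitted
- input_checkpoint: the location of your checkpoint (its prefix, e.g. ./model_pb/name.ckpt)
- checkpoint_version: the checkpoint format version, default 2
- output_graph: the file name of the output graph
- input_binary: whether the input files are binary; it defaults to false, so set it to true for a binary .pb such as the one written with as_text=False
- output_node_names: the names of the output nodes, comma-separated if there are several
- restore_op_name: deprecated and unused, default save/restore_all
- filename_tensor_name: also deprecated and unused
- clear_devices: default true; removes the device placements set during training, e.g. which GPU to use
- initializer_nodes: initializer nodes to run before freezing, comma-separated if there are several; rarely used
- variable_names_whitelist: the variables to freeze, comma-separated; if not specified, all variables are frozen
- variable_names_blacklist: the variables not to freeze, comma-separated
- input_meta_graph: the path to a .meta (MetaGraphDef) file
- input_saved_model_dir: the directory containing a SavedModel file and its variables
- saved_model_tags: the tag(s) of the MetaGraphDef to load, as a comma-separated string (default serve); if the tag set contains several tags, all of them must be passed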
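Despite all these parameters, in practice the essential form is the following. The pb written in step 1 is binary, so --input_binary=true is required, and argparse expects --flag=value without spaces around the equals sign:
```shell
python freeze_graph.py --input_graph=path/to/your.pb \
    --input_binary=true \
    --input_checkpoint=path/to/your.ckpt \
    --output_graph=path/to/output.pb \
    --output_node_names=name_of_your_output_node
```
Set the remaining flags according to your own needs.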
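If you prefer to drive the script from Python, the command line can be assembled programmatically. This is a sketch under assumptions: `freeze_graph_argv` is a hypothetical helper, it runs the script as the module `tensorflow.python.tools.freeze_graph` (the module ends with a `run_main()` entry point), it guesses the format from the file suffix, and `softmax` is only a placeholder output name.

```python
import shlex
import sys


def freeze_graph_argv(input_graph, input_checkpoint, output_graph,
                      output_node_names):
    """Build argv for TensorFlow's freeze_graph script (flags as in run_main)."""
    # Assumption: a ".pbtxt" suffix means text format, anything else is binary.
    binary = not input_graph.endswith(".pbtxt")
    return [
        sys.executable, "-m", "tensorflow.python.tools.freeze_graph",
        "--input_graph=" + input_graph,
        # --input_binary defaults to False, so state the format explicitly.
        "--input_binary=" + ("true" if binary else "false"),
        # The checkpoint prefix, not the .data/.index/.meta file itself.
        "--input_checkpoint=" + input_checkpoint,
        "--output_graph=" + output_graph,
        "--output_node_names=" + ",".join(output_node_names),
    ]


argv = freeze_graph_argv("./model_pb/name.pb", "./model_pb/name.ckpt",
                         "./model_pb/frozen.pb", ["softmax"])
print(" ".join(shlex.quote(a) for a in argv))
```

To actually run it (TensorFlow must be installed), pass the list to `subprocess.run(argv, check=True)`; using a list avoids shell quoting problems with paths.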
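4. Optimize the pb
Some parts of the model that were needed during training are redundant at the inference stage, so the model should be optimized, which mainly means deleting those redundant nodes.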
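```python
import os

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib


def opt_freezed_pb(tmp_dir,
                   input_node_names,
                   output_node_names,
                   input_graph_name,
                   output_graph_name):
    """Strip inference-irrelevant nodes from a frozen pb stored in tmp_dir."""
    input_graph_path = os.path.join(tmp_dir, input_graph_name)
    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(input_graph_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())  # the frozen pb is binary
        # For a text .pbtxt use: text_format.Merge(f.read(), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        input_node_names.split(","),    # comma-separated, as in freeze_graph
        output_node_names.split(","),
        dtypes.float32.as_datatype_enum)  # dtype of the input placeholder(s)

    output_graph_filename = os.path.join(tmp_dir, output_graph_name)
    # A serialized protobuf is binary: open with "wb" and close the file properly.
    with gfile.GFile(output_graph_filename, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    return output_graph_filename
```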
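To get a feel for what removing redundant nodes means, here is a toy version of one idea used by `optimize_for_inference`: bypassing `Identity` nodes (such as the `w/read` node that follows each variable) and rewiring their consumers, roughly what `graph_util.remove_training_nodes` does. The dict format is hypothetical and only plain data inputs are handled; the real library also strips unused nodes, folds batch normalization into the preceding convolution, and more. Protected names (typically the outputs) are kept even if they are `Identity` nodes.

```python
def strip_identity(nodes, protected=()):
    """Toy pass: drop Identity nodes and point consumers at the real producer.

    nodes: {name: {"op": op_type, "inputs": [input names]}}  (hypothetical format,
    plain data inputs only). Names in `protected` (e.g. output nodes) are kept.
    """
    def producer(name):
        # Follow chains such as w -> w/read -> ... back to the real producer.
        while (name in nodes and nodes[name]["op"] == "Identity"
               and name not in protected):
            name = nodes[name]["inputs"][0]
        return name

    return {
        name: {"op": node["op"], "inputs": [producer(i) for i in node["inputs"]]}
        for name, node in nodes.items()
        if node["op"] != "Identity" or name in protected
    }


frozen_toy = {
    "w": {"op": "Const", "inputs": []},
    "w/read": {"op": "Identity", "inputs": ["w"]},
    "x": {"op": "Placeholder", "inputs": []},
    "y": {"op": "MatMul", "inputs": ["x", "w/read"]},
}
print(strip_identity(frozen_toy)["y"])  # {'op': 'MatMul', 'inputs': ['x', 'w']}
```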
You can compare the graphs before and after optimization in TensorBoard to see the exact differences.
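Alongside the visual comparison in TensorBoard, a quick numeric check is to count op types before and after optimization. A minimal stdlib sketch with a hypothetical helper; feed it `[n.op for n in graph_def.node]` for each GraphDef.

```python
from collections import Counter


def op_count_diff(before, after):
    """Return {op_type: (count_before, count_after)} for op types whose count changed."""
    b, a = Counter(before), Counter(after)
    return {op: (b[op], a[op]) for op in sorted(set(b) | set(a)) if b[op] != a[op]}


print(op_count_diff(["Const", "Identity", "Identity", "MatMul"],
                    ["Const", "MatMul"]))  # {'Identity': (2, 0)}
```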
5. Check the optimized pb model
Finally, in Python, run inference with the optimized pb model and check that the results match those of the original ckpt model.
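A sketch of that check, under stated assumptions: `run_frozen_pb` and `max_abs_diff` are hypothetical helpers; `run_frozen_pb` assumes TensorFlow with the `tf.compat.v1` graph API (TF 1.13+ or 2.x) and a single input and output addressed by the node names found in step 2. The comparison itself is plain Python. Because optimization can fold ops such as batch norm and change the floating-point order of operations, compare within a small tolerance rather than for exact equality.

```python
def max_abs_diff(a, b):
    """Largest element-wise |a - b| between two equally shaped nested lists."""
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            raise ValueError("shape mismatch: %d vs %d" % (len(a), len(b)))
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(float(a) - float(b))


def run_frozen_pb(pb_path, input_name, output_name, batch):
    """Run one batch through a frozen/optimized pb. Requires TensorFlow."""
    import tensorflow as tf  # imported lazily; uses the tf.compat.v1 graph API

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names (no "import/" prefix).
        tf.import_graph_def(graph_def, name="")
    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(output_name + ":0", feed_dict={input_name + ":0": batch})


print(max_abs_diff([[1.0, 2.0]], [[1.0, 2.5]]))  # 0.5
```

For example, with `pb_out = run_frozen_pb(path_to_optimized_pb, "input_img", your_output_node, batch)` and `ckpt_out` from your step-1 inference on the same batch, check `max_abs_diff(pb_out.tolist(), ckpt_out.tolist()) < 1e-5`; the `.tolist()` calls turn NumPy arrays into nested lists, and the 1e-5 tolerance is an assumed value to adjust for your model.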
6. And finally, of course, test it with the TensorFlow C++ API!