問題
ValueError: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/moving_mean:0 incompatible with expected resource.
解決辦法:在載入模型之前先呼叫 tf.keras.backend.set_learning_phase(0)(如下方 convert 函式所示),詳見:
https://github.com/keras-team/keras/issues/11032#issuecomment-429989228
code
#! -*- coding: utf-8 -*-
from tensorflow.python.framework import graph_util, graph_io
from tensorflow.python.platform import gfile
from tensorflow import keras as k
import tensorflow as tf
def freeze_graph(graph, session, save_root, save_name, keep_var_name=None, output_names=None, clear_devices=True):
    """Freeze the variables of ``graph`` into constants and write a binary graph file.

    Args:
        graph: Graph whose definition is exported.
        session: Session handed to ``convert_variables_to_constants`` together
            with the graph definition.
        save_root: Directory passed to ``graph_io.write_graph``.
        save_name: File name passed to ``graph_io.write_graph``.
        keep_var_name: Optional iterable of variable op names to exclude from
            the list of variables that get frozen.
        output_names: Optional sequence of output node names. It is copied,
            never modified, so the caller's object is left untouched.
        clear_devices: When true, the ``device`` field of every node is reset
            to the empty string before freezing.
    """
    with graph.as_default():
        # Query the global variables once; both lists below derive from it.
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(global_var_names).difference(keep_var_name or []))
        # Build a NEW list instead of using ``+=``: in-place extension would
        # append every variable name to the list object owned by the caller.
        all_output_names = list(output_names or []) + global_var_names
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        if clear_devices:
            for node in graphdef_inf.node:
                node.device = ""
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(
            session, graphdef_inf, all_output_names, freeze_var_names)
        graph_io.write_graph(graphdef_frozen, save_root, save_name, as_text=False)
def convert(model_path, save_root='./models/', save_name='model.pb'):
    """Load a Keras model file and export it as a frozen binary graph.

    Args:
        model_path: Path passed to ``keras.models.load_model``.
        save_root: Directory the frozen graph is written to. Defaults to
            ``'./models/'``, the location that used to be hard-coded.
        save_name: File name of the frozen graph. Defaults to ``'model.pb'``.
    """
    # The learning phase is pinned to 0 BEFORE the model is loaded; this
    # ordering is the workaround from the Keras issue linked at the top of
    # this file (keras-team/keras#11032).
    tf.keras.backend.set_learning_phase(0)
    model = k.models.load_model(model_path)
    session = tf.keras.backend.get_session()
    freeze_graph(session.graph,
                 session,
                 output_names=[out.op.name for out in model.outputs],
                 save_root=save_root, save_name=save_name)
def show_graph(model_path):
    """Read a serialized ``GraphDef`` and write its graph to ``./logs/``.

    Args:
        model_path: Path of the binary graph file to read.
    """
    # GFile replaces the deprecated FastGFile alias.
    with gfile.GFile(model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a dedicated graph rather than the process-wide default one,
    # so the log holds only the nodes read from the file and not whatever was
    # already built in the default graph (e.g. by an earlier convert() call).
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    writer = tf.summary.FileWriter('./logs/')
    try:
        writer.add_graph(graph)
        writer.flush()
    finally:
        # Close the writer even if adding the graph fails.
        writer.close()
if __name__ == '__main__':
    # Export the Keras model first, then write the exported graph to the
    # summary log.
    keras_model_file = './models/model.h5'
    frozen_graph_file = './models/model.pb'
    convert(keras_model_file)
    show_graph(frozen_graph_file)