tf pb 轉 tfjs 將固定大小的輸入 改爲任意輸入

訓練時使用固定大小, 方便編程實現和速度優化

部署時使用任意大小, 提高體驗

但是採用固定大小輸入做訓練, 部署時採用任意大小, 可能效果有點差別吧....

 

pb文件的網絡結構

 
  1. import tensorflow as tf

  2.  
  3. output_graph_def = tf.GraphDef()

  4. PB_PATH = r"./pb/feathers.pb"

  5.  
  6. with open(PB_PATH, "rb") as f:

  7. output_graph_def.ParseFromString(f.read())

  8. tf.import_graph_def(

  9. output_graph_def,

  10. name='', # 默認name爲import,類似scope

  11. )

  12.  
  13. with tf.Session() as sess:

  14. sess.run(tf.global_variables_initializer())

  15. tf.summary.FileWriter('./log/', sess.graph)

tensorboard.exe --logdir .

a862caf023a621bae890fa6ce0c12e091c1.jpg11a78c39ca187e16df29a689f51bbd42221.jpg

將固定輸入大小的pb文件,轉化爲任意輸入大小的tfjs格式 

 
  1. import tensorflow as tf

  2. from tensorflow.python.framework import graph_util

  3. import tensorflowjs as tfjs

  4.  
  5. sess = tf.Session()

  6. output_graph_def = tf.GraphDef()

  7. # feathers starry candy

  8. PB_PATH = r"./pb/candy.pb"

  9. TFJS_PATH = r'./tfjs/candy'

  10.  
  11. in_image = tf.placeholder(tf.float32, (None, None, None, 3), name='in_x')

  12. with open(PB_PATH, "rb") as f:

  13. output_graph_def.ParseFromString(f.read())

  14. tf.import_graph_def(

  15. output_graph_def,

  16. input_map={

  17. 'in_x:0': in_image

  18. },

  19. name='', # 默認name爲import,類似scope

  20. # return_elements=['generator/mul:0']

  21. )

  22. sess.run(tf.global_variables_initializer())

  23. output = sess.graph.get_tensor_by_name("generator/output:0")

  24.  
  25. with tf.Session() as sess:

  26. sess.run(tf.global_variables_initializer())

  27. constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['generator/output'])

  28. with tf.gfile.FastGFile("./pb/tmp.pb", mode='wb') as f:

  29. f.write(constant_graph.SerializeToString())

  30.  
  31. tfjs.converters.tf_saved_model_conversion_pb.convert_tf_frozen_model(

  32. "./pb/tmp.pb",

  33. 'generator/output',

  34. TFJS_PATH

  35. )

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章