TensorFlow: Printing Network Node Names

1 Test code:

$ cat export_nodename.py

#!/usr/bin/env python
# coding: utf-8
# The encoding declaration must sit on line 1 or 2 to take effect (PEP 263).

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

model_dir = 'work/CNN/CNN2/training'
model_name = 'dnn.pb'

# Read the pre-trained model file and import it into a graph.
def create_graph():
    with tf.gfile.GFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        # Create an empty GraphDef with tf.GraphDef()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

# Build the graph
create_graph()

# Each entry in GraphDef.node is a NodeDef (an op), so these are node names.
node_names = [node.name for node in tf.get_default_graph().as_graph_def().node]
for node_name in node_names:
    print(node_name)
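
The script above uses TensorFlow 1.x names (tf.gfile, tf.GraphDef, tf.get_default_graph), which are not available at the top level of TensorFlow 2.x. The sketch below is one possible variant, not the original script. It assumes a TensorFlow release that provides tf.compat.v1 and tf.io.gfile, parses the file into a GraphDef, and reads the node list directly without importing it into the default graph. It also reports each node's op type. The helper names load_graph_def, node_summaries, and print_nodes are made up for illustration.

```python
def load_graph_def(pb_path):
    """Parse a serialized GraphDef (.pb) file without importing it into a graph."""
    # Imported lazily so that node_summaries() below works without TensorFlow.
    import tensorflow as tf

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def node_summaries(graph_def):
    """Return (name, op) pairs for every node, in the order stored in the file."""
    return [(node.name, node.op) for node in graph_def.node]


def print_nodes(pb_path):
    """Print one line per node: the node name followed by its op type."""
    for name, op in node_summaries(load_graph_def(pb_path)):
        print('{:<24} {}'.format(name, op))
```

With the paths from the script above, the call would be print_nodes('work/CNN/CNN2/training/dnn.pb'). Reading graph_def.node directly should give the same names as the original loop, since import_graph_def is called there with name='' and adds no prefix.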

2 Run the code and view the result

$ python export_nodename.py 
wav_data
decoded_sample_data
AudioSpectrogram
Mfcc
Reshape/shape
Reshape
Reshape_1/shape
Reshape_1
Variable
Variable/read
Variable_1
Variable_1/read
Conv2D
add
Relu
Reshape_2/shape
Reshape_2
Variable_2
Variable_2/read
Variable_3
Variable_3/read
MatMul
add_1
Variable_4
Variable_4/read
Variable_5
Variable_5/read
MatMul_1
add_2
Variable_6
Variable_6/read
Variable_7
Variable_7/read
MatMul_2
add_3
labels_softmax
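
Many names in the listing look like helper nodes rather than layers. In a TensorFlow 1.x graph, a name such as Variable/read is typically the op that reads a variable's value, and Reshape/shape is typically the constant holding the target shape. As a rough sketch, assuming that naming convention holds for this model, the top-level ops can be separated out by dropping every name that contains the scope separator '/'. The function name top_level_nodes is hypothetical.

```python
def top_level_nodes(node_names):
    """Keep only names without a '/' scope separator, preserving order."""
    return [name for name in node_names if '/' not in name]


# A few names copied from the listing above.
sample = ['Reshape/shape', 'Reshape', 'Variable', 'Variable/read',
          'Conv2D', 'add', 'Relu']
print(top_level_nodes(sample))
# prints ['Reshape', 'Variable', 'Conv2D', 'add', 'Relu']
```

This filter is purely name-based, so it would also hide real ops in a model that groups its layers under name scopes.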

 
