從Tensorflow模型文件中解析並顯示網絡結構圖(pb模型篇) 1 Tensor對象與Operation對象 2 提取pb文件中的網絡結構圖 3 繪製網絡結構 4 測試模型顯示 5 源碼地址

最近看到一個巨牛的人工智能教程,分享一下給大家。教程不僅是零基礎,通俗易懂,而且非常風趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。平時碎片時間可以當小說看,【點這裏可以去膜拜一下大神的“小說”】

Tensorflow官方提供的Tensorboard可以可視化神經網絡結構圖,但是說實話,我幾乎從來不用。主要是因爲Tensorboard中查看到的圖結構太混亂了,包含了網絡中所有的計算節點(讀取數據節點、網絡節點、loss計算節點等等)。更可怕的是,如果一個計算節點是由多個基礎計算(如加減乘除等)構成,那麼在Tensorboard中會將基礎計算節點顯示而不是作爲一個整體顯示(典型的如Squeeze計算節點)。最近爲了排查網絡結構BUG花費一週時間,因此,狠下心來決定自己寫一個工具,將Tensorflow中的圖以最簡單的方式顯示最關鍵的網絡結構。

1 Tensor對象與Operation對象

Tensorflow中,Tensor對象主要用於存儲數據如常量和變量(訓練參數),Operation對象是計算節點,如卷積計算、反捲積計算、ReLU等等。每一個Operation對象均有輸入和輸出Tensor,同理,每個Tensor對象均有對應生成該Tensor的Operation對象和使用該Tensor對象作爲輸入的Operation對象。Tensor和Operation對象內均有相關屬性和函數來獲取其關聯的Operation和Tensor對象,相關屬性如下所示。

Tensor對象的op屬性指向生成該Tensor的Operation對象。
Tensor對象的consumers()函數獲取使用該Tensor對象作爲輸入的Operation對象。
Operation對象的inputs屬性指向該計算節點的輸入Tensor對象。
Operation對象的outputs屬性執行該計算節點的輸出Tensor對象。

如下圖所示的網絡結構中,調用Tensor_2對象的consumers()函數,返回的是[op_1,op_2]Tensor_3的op屬性指向的是op_1op_1的inputs屬性指向的是[Tensor_1,Tensor_2]op_1的output屬性指向的是[Tensor_3]

有了Tensor與Operation對應在圖中的關聯關係,就可以將網絡結構給畫出來。

2 提取pb文件中的網絡結構圖

pb文件是將模型參數固化到圖文件中,併合並了一些基礎計算和刪除了反向傳播相關計算得到的protobuf協議文件。如果讀者還不懂如何將CKPT模型文件轉pb文件,請參考我另一篇文章《 Tensorflow MobileNet移植到Android》的第1節部分。有了pb模型文件後,接下來是加載模型,加載pb模型示例代碼如下所示。

def read_graph_from_pb(tf_model_path ,input_names,output_name):  
    with open(tf_model_path, 'rb') as f:
        serialized = f.read() 
    tf.reset_default_graph()
    gdef = tf.GraphDef()
    gdef.ParseFromString(serialized) 
    with tf.Graph().as_default() as g:
        tf.import_graph_def(gdef, name='') 
    
    with tf.Session(graph=g) as sess: 
        OPS=get_ops_from_pb(g,input_names,output_name)
    return OPS

其中,倒數第2行調用到的函數get_ops_from_pb()用於獲取網絡結構圖中指定輸入節點和指定輸出節點之間的計算節點。之所以要指定輸入和輸出,是爲了將輸入之前的計算節點(如加載數據隊列等相關計算節點)和輸出之後的計算節點(如計算loss等相關計算節點)去除,免得礙眼。函數get_ops_from_pb()實現代碼如下。

def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
    if save_ori_network:
        with open('ori_network.txt','w+') as w: 
            OPS=graph.get_operations()
            for op in OPS:
                txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
                w.write(txt+'\n') 
    inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
    output_tf =graph.get_tensor_by_name(output_name) 
    OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] ) 
    with open('network.txt','w+') as w: 
        for op in OPS:
            txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
            w.write(txt+'\n') 
    OPS = sort_ops(OPS)
    OPS = merge_layers(OPS)
    return OPS

在裁剪網絡結構(即只保留input_names和output_name之間節點)之前,先將原始的網絡結構寫入到ori_network.txt中,文件中,每一行寫入:輸入Tensor---->op---->輸出Tensor。接下來調用函數get_ops_from_inputs_outputs獲取指定節點之間的節點。並調用sort_ops函數對所有的節點排序,以保證被依賴的節點總是出現在相關節點之前。最後調用merge_layers函數,將一些可以合併的計算合併成一個獨立的節點,例如,Squeeze計算相關節點合併成一個單獨的Squeeze節點,又如const-->identity兩個計算節點可以直接忽略(即刪除)。

注意:篇幅有限,這裏不再將函數get_ops_from_inputs_outputssort_opsmerge_layers貼出,相關代碼請前往文尾提供的源碼地址中閱讀。

3 繪製網絡結構

考慮到SVG繪製圖形的簡單易用優點,將排好序的網絡計算節點和相關Tensor對象數據以Javascript字符串的形式寫入到HTML中,使用<line>標籤繪製箭頭,使用<rect>標籤繪製矩形,使用<ellipse>標籤繪製橢圓,使用<text>標籤顯示文字。繪製類似於如下所示圖像

注意:篇幅有限,這裏不再介紹Javascript代碼解析模型結構和SVG顯示相關的原理,相關代碼請前往文尾提供的源碼地址中閱讀。

4 測試模型顯示

《MobileNet V1官方預訓練模型的使用》文中介紹的MobileNet V1網絡結構爲例,下載MobileNet_v1_1.0_192文件並壓縮後,得到mobilenet_v1_1.0_192_frozen.pb文件。我們還需要知道mobilenet_v1_1.0_192_frozen.pb模型對應的輸入和輸出Tensor對象的名稱,好在MobileNet_v1_1.0_192壓縮包中包含文件mobilenet_v1_1.0_192_info.txt。通過該文件可知,輸入Tensor的名稱爲:input:0,輸出Tensor名稱爲:MobilenetV1/Predictions/Reshape_1:0。有了這些信息後,調用函數read_graph_from_pb得到靜態圖的節點列表對象ops,調用函數gen_graph(ops,"save/path/graph.html")後,在目錄save/path中得到graph.html文件,打開graph.html後,顯示結果如下。

顯示網絡結構分兩種模式:合併模式和展開模式,分別如下圖所示。

5 源碼地址

https://github.com/huachao1001/CNNGraph

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章