Tensorflow框架實現中的“三”種圖

原文鏈接:https://zhuanlan.zhihu.com/p/31308381

圖(graph)是 tensorflow 用於表達計算任務的一個核心概念。從前端(python)描述神經網絡的結構,到後端在多機和分佈式系統上部署,到底層 Device(CPU、GPU、TPU)上運行,都是基於圖來完成。然而我在實際使用過程中遇到了三對API:

  1. tf.train.Saver()/saver.restore()
  2. export_meta_graph/Import_meta_graph
  3. tf.train.write_graph()/tf.Import_graph_def()

他們都是用於對圖的保存和恢復。同一個計算框架,爲什麼需要三對不同的API呢?他們保存/恢復的圖在使用時又有什麼區別呢?初學的時候,常常鬧不清楚他們的區別,以至常常寫出了錯誤的程序,經過一番研究,在本文中對Tensorflow中圍繞Graph的核心概念進行了總結。

Graph
首先介紹一下關於 Tensorflow 中 Graph 和它的序列化表示 Graph_def。在Tensorflow的官方文檔中,Graph 被定義爲“一些 Operation 和 Tensor 的集合”。例如我們表達如下的一個計算的 python代碼:

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
d = a*b+c
e = d*2

就會生成相應的一張圖,在Tensorboard中看到的圖大概如下這樣。其中每一個圓圈表示一個Operation(輸入處爲Placeholder),橢圓到橢圓的邊爲Tensor,箭頭的指向表示了這張圖Operation 輸入輸出 Tensor 的傳遞關係。
在這裏插入圖片描述
這張圖所表達的數據流 與 python 代碼中所表達的計算是對應的關係(爲了稱呼方便,我們下面將這張由Python表達式所描述的數據流動關係叫做 Python Graph)。然而在真實的 Tensorflow 運行中,Python 構建的“圖”並不是啓動一個Session之後始終不變的東西。因爲Tensorflow在運行時,真實的計算會被下放到多CPU上,或者 GPU 等異構設備,或者ARM等上進行高性能/能效的計算。單純使用 Python 肯定是無法有效完成的。實際上,Tensorflow而是首先將 python 代碼所描繪的圖轉換(即“序列化”)成 Protocol Buffer,再通過 C/C++/CUDA 運行 Protocol Buffer 所定義的圖。(Protocol Buffer的介紹可以參考這篇文章學習:https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/)

GraphDef
從 python Graph中序列化出來的圖就叫做 GraphDef(這是一種不嚴格的說法,先這樣進行理解)。而 GraphDef 又是由許多叫做 NodeDef 的 Protocol Buffer 組成。在概念上 NodeDef 與 (Python Graph 中的)Operation 相對應。如下就是 GraphDef 的 ProtoBuf,由許多node組成的圖表示。這是與上文 Python 圖對應的 GraphDef:

node {
  name: "Placeholder"     # 註釋:這是一個叫做 "Placeholder" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "Placeholder_1"     # 註釋:這是一個叫做 "Placeholder_1" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "mul"                 # 註釋:一個 Mul(乘法)操作
  op: "Mul"
  input: "Placeholder"        # 使用上面的node(即Placeholder和Placeholder_1)
  input: "Placeholder_1"      # 作爲這個Node的輸入
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

以上三個 NodeDef 定義了兩個Placeholder和一個Multiply。Placeholder 通過 attr(attribute的縮寫)來定義數據類型和 Tensor 的形狀。Multiply通過 input 屬性定義了兩個placeholder作爲其輸入。無論是 Placeholder 還是 Multiply 都沒有關於輸出(output)的信息。其實 Tensorflow 中都是通過 Input 來定義 Node 之間的連接信息。

那麼既然 tf.Operation 的序列化 ProtoBuf 是 NodeDef,那麼 tf.Variable 呢?在這個 GraphDef 中只有網絡的連接信息,卻沒有任何 Variables呀?沒錯,Graphdef中不保存任何 Variable 的信息,所以如果我們從 graph_def 來構建圖並恢復訓練的話,是不能成功的。比如以下代碼:

with tf.Graph().as_default() as graph:
  tf.import_graph_def("graph_def_path")
  saver= tf.train.Saver()
  with tf.Session() as sess:
    tf.trainable_variables()

其中 tf.trainable_variables() 只會返回一個空的list。Tf.train.Saver() 也會報告 no variables to save。

然而,在實際線上 inference 中,通常就是使用 GraphDef。然而,GraphDef中連Variable都沒有,怎麼存儲weight呢?原來GraphDef 雖然不能保存 Variable,但可以保存 Constant 。通過 tf.constant 將 weight 直接存儲在 NodeDef 裏,tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具來自動的將圖中的 Variable 替換成 constant 存儲在 GraphDef 裏面,並將該圖導出爲 Proto。可以查看以下鏈接獲取更多信息,

https://www.tensorflow.org/extend/tool_developers/

https://www.tensorflow.org/mobile/prepare_models

tf.train.write_graph()/tf.Import_graph_def() 就是用來進行 GraphDef 讀寫的API。那麼,我們怎麼才能從序列化的圖中,得到 Variables呢?這就要學習下一個重要概念,MetaGraph。

MetaGraph
Meta graph 的官方解釋是:一個Meta Graph 由一個計算圖和其相關的元數據構成。其包含了用於繼續訓練,實施評估和(在已訓練好的的圖上)做前向推斷的信息。(A MetaGraph consists of both a computational graph and its associated metadata. A MetaGraph contains the information required to continue training, perform evaluation, or run inference on a previously trained graph. From https://www.tensorflow.org/versions/r1.1/programmers_guide/

這一段看的雲裏霧裏,不過這篇文章(https://www.tensorflow.org/versions/r1.1/programmers_guide/meta_graph)進一步解釋說,Meta Graph在具體實現上就是一個MetaGraphDef (同樣是由 Protocol Buffer來定義的)。其包含了四種主要的信息,根據Tensorflow官網,這四種 Protobuf 分別是

  1. MetaInfoDef,存一些元信息(比如版本和其他用戶信息)
  2. GraphDef, MetaGraph的核心內容之一,我們剛剛介紹過
  3. SaverDef,圖的Saver信息(比如最多同時保存的check-point數量,需保存的Tensor名字等,但並不保存Tensor中的實際內容)
  4. CollectionDef 任何需要特殊注意的 Python 對象,需要特殊的標註以方便import_meta_graph 後取回。(比如“train_op”,"prediction"等等)

在以上四種 ProtoBuf 裏面,1 和 3 都比較容易理解,2 剛剛總結過。這裏特別要講一下Collection(CollectionDef是對應的ProtoBuf)。

Tensorflow中並沒有一個官方的定義說 collection 是什麼。簡單的理解,它就是爲了方別用戶對圖中的操作和變量進行管理,而創建的一個概念。它可以說是一種“集合”,通過一個 key(string類型)來對一組 Python 對象進行命名的集合。這個key既可以是tensorflow在內部定義的一些key,也可以是用戶自己定義的名字(string)。

Tensorflow 內部定義了許多標準 Key,全部定義在了 tf.GraphKeys 這個類中。其中有一些常用的,tf.GraphKeys.TRAINABLE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES 等等。tf.trainable_variables() 與 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 是等價的;tf.global_variables() 與 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 是等價的。

對於用戶定義的 key,我們舉一個例子。例如:

pred = model_network(X)
loss=tf.reduce_mean(…, pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)

這樣一段 Tensorflow程序,用戶希望特別關注 pred, loss train_op 這幾個操作,那麼就可以使用如下代碼,將這幾個變量加入到 collection 中去。(假設我們將其命名爲 “training_collection”)

tf.add_to_collection("training_collection", pred)
tf.add_to_collection("training_collection", loss)
tf.add_to_collection("training_collection", train_op)

並且可以通過 Train_collect = tf.get_collection(“training_collection”) 得到一個python list,其中的內容就是 pred, loss, train_op的 Tensor。這通常是爲了在一個新的 session 中打開這張圖時,方便我們獲取想要的操作。比如我們可以直接工通過get_collection() 得到 train_op,然後通過sess.run(train_op)來開啓一段訓練,而無需重新構建 loss 和optimizer。

通過export_meta_graph保存圖,並且通過 add_to_collection 將 train_op 加入到 collection中:

with tf.Session() as sess:
  pred = model_network(X)
  loss=tf.reduce_mean(…,pred, …)
  train_op=tf.train.AdamOptimizer(lr).minimize(loss)
  tf.add_to_collection("training_collection", train_op)
  Meta_graph_def = 
      tf.train.export_meta_graph(tf.get_default_graph(), 'my_graph.meta')

通過 import_meta_graph將圖恢復(同時初始化爲本 Session的 default 圖),並且通過get_collection 重新獲得 train_op,以及通過 train_op 來開始一段訓練( sess.run() )。

with tf.Session() as new_sess:
  tf.train.import_meta_graph('my_graph.meta')
  train_op = tf.get_collection("training_collection")[0]
  new_sess.run(train_op)

更多的代碼例子可以在這篇文檔(https://www.tensorflow.org/api_guides/python/meta_graph)中的 Import a MetaGraph 章節中看到。

那麼,從 Meta Graph 中恢復構建的圖可以被訓練嗎?是可以的。Tensorflow的官方文檔 https://www.tensorflow.org/api_guides/python/meta_graph 說明了使用方法。這裏要特殊的說明一下,Meta Graph中雖然包含Variable的信息,卻沒有 Variable 的實際值。所以從Meta Graph中恢復的圖,其訓練是從隨機初始化的值開始的。訓練中Variable的實際值都保存在check-point中,如果要從之前訓練的狀態繼續恢復訓練,就要從check-point中restore。進一步讀一下Export Meta Graph的代碼,可以看到,事實上variables並沒有被export到meta_graph 中

https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/training/saver.py (1872行)

https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/meta_graph.py (829,845行)

export_meta_graph/Import_meta_graph 就是用來進行 Meta Graph 讀寫的API。tf.train.saver.save() 在保存check-point的同時也會保存Meta Graph。但是在恢復圖時,tf.train.saver.restore() 只恢復 Variable,如果要從MetaGraph恢復圖,需要使用 import_meta_graph。這是其實爲了方便用戶,有時我們不需要從MetaGraph恢復的圖,而是需要在 python 中構建神經網絡圖,並恢復對應的 Variable。

Check-point
Check-point 裏全面保存了訓練某時間截面的信息,包括參數,超參數,梯度等等tf.train.Saver()/saver.restore() 則能夠完完整整保存和恢復神經網絡的訓練。Check-point分爲兩個文件保存Variable的二進制信息。ckpt文件保存了Variable的二進制信息,index文件用於保存 ckpt 文件中對應 Variable 的偏移量信息。

總結
Tensorflow 三種 API 所保存和恢復的圖是不一樣的。這三種圖是從Tensorflow框架設計的角度出發而定義的。但是從用戶的角度來看,TF文檔的寫作難免有些雲裏霧裏,弄不清他們的區別。需要讀一讀Tensorflow的代碼,做一些實驗來對他們進行辨析。

簡而言之,Tensorflow 在前端 Python 中構建圖,並且通過將該圖序列化到 ProtoBuf GraphDef,以方便在後端運行。在這個過程中,圖的保存、恢復和運行都通過 ProtoBuf 來實現。GraphDef,MetaGraph,以及Variable,Collection 和 Saver 等都有對應的 ProtoBuf 定義。ProtoBuf 的定義也決定了用戶能對圖進行的操作。例如用戶只能找到Node的前一個Node,卻無法得知自己的輸出會由哪個Node接收。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章