Tensorflow源碼解析3 -- TensorFlow核心對象 - Graph

1 Graph概述

計算圖Graph是TensorFlow的核心對象,TensorFlow的運行流程基本都是圍繞它進行的。包括圖的構建、傳遞、剪枝、按worker分裂、按設備二次分裂、執行、註銷等。因此理解計算圖Graph對掌握TensorFlow運行尤爲關鍵。

2 默認Graph

默認圖替換

之前講解Session的時候就說過,一個Session只能run一個Graph,但一個Graph可以運行在多個Session中。常見情況是,session會運行全局唯一的隱式的默認的Graph,operation也是註冊到這個Graph中。

也可以顯示創建Graph,並調用as_default()使他替換默認Graph。在該上下文管理器中創建的op都會註冊到這個graph中。退出上下文管理器後,則恢復原來的默認graph。一般情況下,我們不用顯式創建Graph,使用系統創建的那個默認Graph即可。

print tf.get_default_graph()

with tf.Graph().as_default() as g:
    print tf.get_default_graph() is g
    print tf.get_default_graph()

print tf.get_default_graph()

輸出如下

<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
True
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>

由此可見,在上下文管理器中,當前線程的默認圖被替換了,而退出上下文管理後,則恢復爲了原來的默認圖。

默認圖管理

默認graph和默認session一樣,也是線程作用域的。當前線程中,永遠都有且僅有一個graph爲默認圖。TensorFlow同樣通過棧來管理線程的默認graph。

@tf_export("Graph")
class Graph(object):
    # 替換線程默認圖
    def as_default(self):
        return _default_graph_stack.get_controller(self)
    
    # 棧式管理,push pop
    @tf_contextlib.contextmanager
    def get_controller(self, default):
        try:
          context.context_stack.push(default.building_function, default.as_default)
        finally:
          context.context_stack.pop()

替換默認圖採用了堆棧的管理方式,通過push pop操作進行管理。獲取默認圖的操作如下,通過默認graph棧_default_graph_stack來獲取。

@tf_export("get_default_graph")
def get_default_graph():
  return _default_graph_stack.get_default()

下面來看_default_graph_stack的創建

_default_graph_stack = _DefaultGraphStack()
class _DefaultGraphStack(_DefaultStack):  
  def __init__(self):
    # 調用父類來創建
    super(_DefaultGraphStack, self).__init__()
    self._global_default_graph = None
    
class _DefaultStack(threading.local):
  def __init__(self):
    super(_DefaultStack, self).__init__()
    self._enforce_nesting = True
    # 和默認session棧一樣,本質上也是一個list
    self.stack = []

_default_graph_stack的創建如上所示,最終和默認session棧一樣,本質上也是一個list。

3 前端Graph數據結構

Graph數據結構

理解一個對象,先從它的數據結構開始。我們先來看Python前端中,Graph的數據結構。Graph主要的成員變量是Operation和Tensor。Operation是Graph的節點,它代表了運算算子。Tensor是Graph的邊,它代表了運算數據。

@tf_export("Graph")
class Graph(object):
    def __init__(self):
           # 加線程鎖,使得註冊op時,不會有其他線程註冊op到graph中,從而保證共享graph是線程安全的
        self._lock = threading.Lock()
        
        # op相關數據。
        # 爲graph的每個op分配一個id,通過id可以快速索引到相關op。故創建了_nodes_by_id字典
        self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
        self._next_id_counter = 0  # GUARDED_BY(self._lock)
        # 同時也可以通過name來快速索引op,故創建了_nodes_by_name字典
        self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
        self._version = 0  # GUARDED_BY(self._lock)
        
        # tensor相關數據。
        # 處理tensor的placeholder
        self._handle_feeders = {}
        # 處理tensor的read操作
        self._handle_readers = {}
        # 處理tensor的move操作
        self._handle_movers = {}
        # 處理tensor的delete操作
        self._handle_deleters = {}

下面看graph如何添加op的,以及保證線程安全的。

  def _add_op(self, op):
    # graph被設置爲final後,就是隻讀的了,不能添加op了。
    self._check_not_finalized()
    
    # 保證共享graph的線程安全
    with self._lock:
      # 將op以id和name分別構建字典,添加到_nodes_by_id和_nodes_by_name字典中,方便後續快速索引
      self._nodes_by_id[op._id] = op
      self._nodes_by_name[op.name] = op
      self._version = max(self._version, op._id)

GraphKeys 圖分組

每個Operation節點都有一個特定的標籤,從而實現節點的分類。相同標籤的節點歸爲一類,放到同一個Collection中。標籤是一個唯一的GraphKey,GraphKey被定義在類GraphKeys中,如下

@tf_export("GraphKeys")
class GraphKeys(object):
    GLOBAL_VARIABLES = "variables"
    QUEUE_RUNNERS = "queue_runners"
    SAVERS = "savers"
    WEIGHTS = "weights"
    BIASES = "biases"
    ACTIVATIONS = "activations"
    UPDATE_OPS = "update_ops"
    LOSSES = "losses"
    TRAIN_OP = "train_op"
    # 省略其他

name_scope 節點命名空間

使用name_scope對graph中的節點進行層次化管理,上下層之間通過斜槓分隔。

# graph節點命名空間
g = tf.get_default_graph()
with g.name_scope("scope1"):
    c = tf.constant("hello, world", name="c")
    print c.op.name

    with g.name_scope("scope2"):
        c = tf.constant("hello, world", name="c")
        print c.op.name

輸出如下

scope1/c
scope1/scope2/c  # 內層的scope會繼承外層的,類似於棧,形成層次化管理


4 後端Graph數據結構

Graph

先來看graph.h文件中的Graph類的定義,只看關鍵代碼

 class Graph {
     private:
      // 所有已知的op計算函數的註冊表
      FunctionLibraryDefinition ops_;

      // GraphDef版本號
      const std::unique_ptr<VersionDef> versions_;

      // 節點node列表,通過id來訪問
      std::vector<Node*> nodes_;

      // node個數
      int64 num_nodes_ = 0;

      // 邊edge列表,通過id來訪問
      std::vector<Edge*> edges_;

      // graph中非空edge的數目
      int num_edges_ = 0;

      // 已分配了內存,但還沒使用的node和edge
      std::vector<Node*> free_nodes_;
      std::vector<Edge*> free_edges_;
 }

後端中的Graph主要成員也是節點node和邊edge。節點node爲計算算子Operation,邊爲算子所需要的數據,或者代表節點間的依賴關係。這一點和Python中的定義相似。邊Edge的持有它的源節點和目標節點的指針,從而將兩個節點連接起來。下面看Edge類的定義。

Edge

class Edge {
     private:
      Edge() {}

      friend class EdgeSetTest;
      friend class Graph;
      // 源節點, 邊的數據就來源於源節點的計算。源節點是邊的生產者
      Node* src_;

      // 目標節點,邊的數據提供給目標節點進行計算。目標節點是邊的消費者
      Node* dst_;

      // 邊id,也就是邊的標識符
      int id_;

      // 表示當前邊爲源節點的第src_output_條邊。源節點可能會有多條輸出邊
      int src_output_;

      // 表示當前邊爲目標節點的第dst_input_條邊。目標節點可能會有多條輸入邊。
      int dst_input_;
};

Edge既可以承載tensor數據,提供給節點Operation進行運算,也可以用來表示節點之間有依賴關係。對於表示節點依賴的邊,其src_output_, dst_input_均爲-1,此時邊不承載任何數據。

下面來看Node類的定義。

Node

class Node {
 public:
    // NodeDef,節點算子Operation的信息,比如op分配到哪個設備上了,op的名字等,運行時有可能變化。
      const NodeDef& def() const;
    
    // OpDef, 節點算子Operation的元數據,不會變的。比如Operation的入參列表,出參列表等
      const OpDef& op_def() const;
 private:
      // 輸入邊,傳遞數據給節點。可能有多條
      EdgeSet in_edges_;

      // 輸出邊,節點計算後得到的數據。可能有多條
      EdgeSet out_edges_;
}

節點Node中包含的主要數據有輸入邊和輸出邊的集合,從而能夠由Node找到跟他關聯的所有邊。Node中還包含NodeDef和OpDef兩個成員。NodeDef表示節點算子的信息,運行時可能會變,創建Node時會new一個NodeDef對象。OpDef表示節點算子的元信息,運行時不會變,創建Node時不需要new OpDef,只需要從OpDef倉庫中取出即可。因爲元信息是確定的,比如Operation的入參個數等。

由Node和Edge,即可以組成圖Graph,通過任何節點和任何邊,都可以遍歷完整圖。Graph執行計算時,按照拓撲結構,依次執行每個Node的op計算,最終即可得到輸出結果。入度爲0的節點,也就是依賴數據已經準備好的節點,可以併發執行,從而提高運行效率。

系統中存在默認的Graph,初始化Graph時,會添加一個Source節點和Sink節點。Source表示Graph的起始節點,Sink爲終止節點。Source的id爲0,Sink的id爲1,其他節點id均大於1.

5 Graph運行時生命週期

Graph是TensorFlow的核心對象,TensorFlow的運行均是圍繞Graph進行的。運行時Graph大致經過了以下階段

  1. 圖構建:client端用戶將創建的節點註冊到Graph中,一般不需要顯示創建Graph,使用系統創建的默認的即可。
  2. 圖發送:client通過session.run()執行運行時,將構建好的整圖序列化爲GraphDef後,傳遞給master
  3. 圖剪枝:master先反序列化拿到Graph,然後根據session.run()傳遞的fetches和feeds列表,反向遍歷全圖full graph,實施剪枝,得到最小依賴子圖。
  4. 圖分裂:master將最小子圖分裂爲多個Graph Partition,並註冊到多個worker上。一個worker對應一個Graph Partition。
  5. 圖二次分裂:worker根據當前可用硬件資源,如CPU GPU,將Graph Partition按照op算子設備約束規範(例如tf.device(’/cpu:0’),二次分裂到不同設備上。每個計算設備對應一個Graph Partition。
  6. 圖運行:對於每一個計算設備,worker依照op在kernel中的實現,完成op的運算。設備間數據通信可以使用send/recv節點,而worker間通信,則使用GRPC或RDMA協議。

這些階段根據TensorFlow運行時的不同,會進行不同的處理。運行時有兩種,本地運行時和分佈式運行時。故Graph生命週期到後面分析本地運行時和分佈式運行時的時候,再詳細講解。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章