Tensorflow源碼解析5 -- 圖的邊 - Tensor

1 概述

前文兩篇文章分別講解了TensorFlow核心對象Graph,和Graph的節點Operation。Graph另外一大成員,即爲其邊Tensor。邊用來表示計算的數據,它經過上游節點計算後得到,然後傳遞給下游節點進行運算。本文講解Graph的邊Tensor,以及TensorFlow中的變量。

2 前端邊Tensor數據結構

Tensor作爲Graph的邊,使得節點Operation之間建立了連接。上游源節點Operation經過計算得到數據Tensor,然後傳遞給下游目標節點,是一個典型的生產者-消費者關係。下面來看Tensor的數據結構

@tf_export("Tensor")
class Tensor(_TensorLike):
  def __init__(self, op, value_index, dtype):
    # 源節點,tensor的生產者,會計算得到tensor
    self._op = op

    # tensor在源節點的輸出邊集合中的索引。源節點可能會有多條輸出邊
    # 利用op和value_index即可唯一確定tensor。
    self._value_index = value_index

    # tensor中保存的數據的數據類型
    self._dtype = dtypes.as_dtype(dtype)

    # tensor的shape,可以得到張量的rank,維度等信息
    self._shape_val = tensor_shape.unknown_shape()

    # 目標節點列表,tensor的消費者,會使用該tensor來進行計算
    self._consumers = []

    #
    self._handle_data = None
    self._id = uid()

Tensor中主要包含兩類信息,一個是Graph結構信息,如邊的源節點和目標節點。另一個則是它所保存的數據信息,例如數據類型,shape等。

3 後端邊Edge數據結構

後端中的Graph主要成員也是節點node和邊edge。節點node爲計算算子Operation,邊Edge爲算子所需要的數據,或者代表節點間的依賴關係。這一點和Python中的定義相似。邊Edge的持有它的源節點和目標節點的指針,從而將兩個節點連接起來。下面看Edge類的定義。

class Edge {
   private:
      Edge() {}

      friend class EdgeSetTest;
      friend class Graph;
      // 源節點, 邊的數據就來源於源節點的計算。源節點是邊的生產者
      Node* src_;

      // 目標節點,邊的數據提供給目標節點進行計算。目標節點是邊的消費者
      Node* dst_;

      // 邊id,也就是邊的標識符
      int id_;

      // 表示當前邊爲源節點的第src_output_條邊。源節點可能會有多條輸出邊
      int src_output_;

      // 表示當前邊爲目標節點的第dst_input_條邊。目標節點可能會有多條輸入邊。
      int dst_input_;
};

Edge既可以承載tensor數據,提供給節點Operation進行運算,也可以用來表示節點之間有依賴關係。對於表示節點依賴的邊,其src_output_, dst_input_均爲-1,此時邊不承載任何數據。

4 常量constant

TensorFlow的常量constant,最終包裝成了一個Tensor。通過tf.constant(10),返回一個Tensor對象。

@tf_export("constant")
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
  # 算子註冊到默認Graph中
  g = ops.get_default_graph()
    
  # 對常量值value的處理
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value, dtype=dtype, shape=shape, verify_shape=verify_shape))

  # 對常量值的類型dtype進行處理
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)

  # 構造並註冊類型爲“Const”的算子到Graph中,從算子的outputs中取出輸出的tensor。
  const_tensor = g.create_op(
      "Const", [], [dtype_value.type],
      attrs={"value": tensor_value,
             "dtype": dtype_value},
      name=name).outputs[0]
  return const_tensor

tf.constant的過程爲

  1. 獲取默認graph
  2. 對常量值value和常量值的類型dtype進行處理
  3. 構造並註冊類型爲“Const”的算子到默認graph中,從算子的outputs中取出輸出的tensor。

此時只是圖的構造過程,tensor並未承載數據,僅表示Operation輸出的一個符號句柄。經過tensor.eval()或session.run()後,纔會啓動graph的執行,並得到數據。

5 變量Variable

Variable構造器

通過tf.Variable()構造一個變量,代碼如下,我們僅分析入參。

@tf_export("Variable")
class Variable(object):
  def __init__(self,
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               expected_shape=None,
               import_scope=None,
               constraint=None):
# initial_value: 初始值,爲一個tensor,或者可以被包裝爲tensor的值
# trainable:是否可以訓練,如果爲false,則訓練時不會改變
# collections:變量要加入哪個集合中,有全局變量集合、本地變量集合、可訓練變量集合等。默認加入全局變量集合中
# dtype:變量的類型

主要的入參代碼中已經給出了註釋。Variable可以接受一個tensor或者可以被包裝爲tensor的值,來作爲初始值。事實上,Variable可以看做是Tensor的包裝器,它重載了Tensor的幾乎所有操作,是對Tensor的進一步封裝。

Variable初始化

變量只有初始化後才能使用,初始化時將initial_value初始值賦予Variable內部持有的Tensor。通過運行變量的初始化器可以對變量進行初始化,也可以執行全局初始化器。如下

y = tf.Variable([5.3])

with tf.Session() as sess:
    initialization = tf.global_variables_initializer()
    print sess.run(y)

Variable集合

Variable被劃分到不同的集合中,方便後續操作。常見的集合有

  1. 全局變量:全局變量可以在不同進程中共享,可運用在分佈式環境中。變量默認會加入到全局變量集合中。通過tf.global_variables()可以查詢全局變量集合。其op標示爲GraphKeys.GLOBAL_VARIABLES

    @tf_export("global_variables")
    def global_variables(scope=None):
      return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
  2. 本地變量:運行在進程內的變量,不能跨進程共享。通常用來保存臨時變量,如訓練迭代次數epoches。通過tf.local_variables()可以查詢本地變量集合。其op標示爲GraphKeys.LOCAL_VARIABLES

    @tf_export("local_variables")
    def local_variables(scope=None):
        return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
  3. 可訓練變量:一般模型參數會放到可訓練變量集合中,訓練時,做這些變量會得到改變。不在這個集合中的變量則不會得到改變。默認會放到此集合中。通過tf.trainable_variables()可以查詢。其op標示爲GraphKeys.TRAINABLE_VARIABLES

    @tf_export("trainable_variables")
    def trainable_variables(scope=None):
      return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)

其他集合還有model_variables,moving_average_variables。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章