tensorflow中共享變量 tf.get_variable 和命名空間 tf.variable_scope

tensorflow中有很多需要變量共享的場合,比如在多個GPU上訓練網絡時網絡參數和訓練數據就需要共享。

tf通過 tf.get_variable() 可以建立或者獲取一個共享的變量。 tf.get_variable函數的作用從tf的註釋裏就可以看出來-- ‘Gets an existing variable with this name or create a new one’。

 

與 tf.get_variable 函數相對的還有一個 tf.Variable 函數,兩者的區別是:

 

  • tf.Variable定義變量的時候會自動檢測命名衝突並自行處理,例如已經定義了一個名稱是 ‘wg_1’的變量,再使用tf.Variable定義名稱是‘wg_1’的變量,會自動把後一個變量的名稱更改爲‘wg_1_0’,實際相當於創建了兩個變量,tf.Variable不可以創建共享變量。
  • tf.get_variable定義變量的時候不會自動處理命名衝突,如果遇到重名的變量並且創建該變量時沒有設置爲共享變量,tf會直接報錯。

 

變量可以共享之後還有一個問題就是當模型很大很複雜的時候,變量和操作的數量也比較龐大,爲了方便對這些變量進行管理,維護條理清晰的graph結構,tf建立了一套共享機制,通過 變量作用域(命名空間,variable_scope)實現對變量的共享和管理。例如,cnn的每一層中,均有weights和biases這兩個變量,通過tf.variable_scope()爲每一卷積層命名,就可以防止變量命名重複。

與 tf.variable_scope相對的還有一個 tf.name_scope 函數,兩者的區別是:

 

  • tf.name_scope 主要用於管理一個圖(graph)裏面的各種操作,返回的是一個以scope_name命名的context manager。一個graph會維護一個name_space的堆,每一個namespace下面可以定義各種op或者子namespace,實現一種層次化有條理的管理,避免各個op之間命名衝突。
  • tf.variable_scope 一般與tf.name_scope()配合使用,用於管理一個圖(graph)中變量的名字,避免變量之間的命名衝突,tf.variable_scope允許在一個variable_scope下面共享變量。
# coding: utf-8
# Created by -牧野- https://blog.csdn.net/dcrmg/article/details/79794501
import tensorflow as tf

# 定義的基本等價
v1 = tf.get_variable("v", shape=[1], initializer= tf.constant_initializer(1.0))
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="v")

with tf.variable_scope("abc"):
    v3=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))

# 在變量作用域內定義變量,不同變量作用域內的變量命名可以相同
with tf.variable_scope("xyz"):
    v4=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))

with tf.variable_scope("xyz", reuse=True):
    v5 = tf.get_variable("v")
    v6 = tf.get_variable("v",[1])

with tf.variable_scope("foo"):
    v7 = tf.get_variable("v", [1])

    # 通過 tf.get_variable_scope().reuse_variables() 設置以下的變量是共享變量;
    # 如果不加,v8的定義會由於重名而報錯
    tf.get_variable_scope().reuse_variables()
    v8 = tf.get_variable("v", [1])
assert v7 is v8


with tf.variable_scope("foo_1") as foo_scope:
    v = tf.get_variable("v", [1])
with tf.variable_scope(foo_scope):
    w = tf.get_variable("w", [1])
with tf.variable_scope(foo_scope, reuse=True):
    v1 = tf.get_variable("v", [1])
    w1 = tf.get_variable("w", [1])
assert v1 is v
assert w1 is w


with tf.variable_scope("foo1"):
    with tf.name_scope("bar1"):
        v_1 = tf.get_variable("v", [1])
        x_1 = 1.0 + v_1
assert v_1.name == "foo1/v:0"
assert x_1.op.name == "foo1/bar1/add"


print v1==v2  # False
print v3==v4  # False 不同變量作用域中
print v3.name  # abc/v:0
print v4==v5  # 輸出爲True
print v5==v6  # True
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章