TensorFlow中的變量管理

本文參考《TensorFlow實戰Google深度學習框架》一書,總結了一些在TensorFlow中與變量管理相關的一些API和使用技巧

1.創建變量

TensorFlow中可以通過tf.Variable和tf.get_variable兩個函數來創建變量,兩者基本功能相同,但是用法存在差別。

#下面兩個定義是等價的,只不過變量的名稱不同
v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = 'v1')
v2 = tf.get_variable('v2', shape = [1], initializer = tf.constant_initializer(1.0))
tf.get_variable在函數調用時名稱是必須輸入的,而tf.Variable則不是必選的。此外tf.get_variable函數調用時提供的維度和初始化方法與tf.Variable也類似,TF中提供的initializer函數與隨機數和常量的生成函數大部分是一一對應的。


在變量初始化的過成中,tf.Variable可以使用其他已經初始化的變量對其進行初始化

v3 = tf.Variable(v1.initialized_value() * 2, name = 'v3')
而tf.get_variable函數可以通過tf.variable_scope函數來生成一個上下文管理器,並且明確指定在這個上下文管理器中,tf.get_variable將直接獲取已經生成的變量。
import tensorflow as tf
with tf.variable_scope('scope1', reuse = False) as scope:
   print(tf.get_variable_scope().reuse)
   x1 = tf.get_variable('x', [1])
with tf.variable_scope('scope1', reuse = True) as scope:
   print(tf.get_variable_scope().reuse)
   x2 = tf.get_variable('x', [1])
   print(x1 == x2)#輸出True
tf.variable_scope函數的reuse默認爲None。當reuse=False或者None時,tf.get_variable將創建新的變量,如果同名的變量已經存在了,那麼會報錯。如果reuse=True,tf.get_variable函數將會直接獲取已經創建的變量,如果變量不存在,則會報錯。

此外tf.variable_scope函數可以嵌套,當有外層已經被指定爲reuse=True之後,內層嵌套的其他同類函數的reuse都會被默認設置爲True,除非有顯示地指明。此處就不再以代碼示人。

2.命名空間

tf.variable_scope生成的上下文管理器會創建一個TF的命名空間,在該空間內創建的變量name都會帶上這個命名空間名作爲前綴。命名空間隨着tf.variable_scope的嵌套,也可以進行嵌套,會有不同的name屬性。但是如果在內層中使用的變量標識符與外層使用的相同,則該變量會被更新。如果是並列的沒有包含關係的命名空間,使用相同的標識符表示變量則不會有衝突。
import tensorflow as tf
sess = tf.InteractiveSession()
with tf.variable_scope('scope1') as scope:
   print(tf.get_variable_scope().reuse)
   x1 = tf.get_variable('x', [1], initializer = tf.constant_initializer(1))
   print(x1.name)#print scope1/x:0
   with tf.variable_scope('scope2') as scope:
      x2 = tf.get_variable('x', [1],initializer = tf.constant_initializer(2))#如果這裏使用x1作爲標識符,則x1會被更新爲2
      print(x2.name)#print scope1/scope2/x:0
with tf.variable_scope('scope1', reuse = True) as scope:
   x3 = tf.get_variable('x', [1])
   print(x3.name)
   print(x1 == x3)#print scope1/x:0
with tf.variable_scope('', reuse = True) as scope:#這裏只能用空的名稱
   print(tf.get_variable_scope().reuse)
   x4 = tf.get_variable('scope1/x', [1], initializer = tf.constant_initializer(3))#這裏的初始化沒有作用,仍然爲1
   print(x4.name)#print scope1/x:0
   print(x4 == x1)#True 如果scope2中的標識符也叫x1,那麼會輸出Fasle

sess.run(tf.initialize_all_variables())
print(x1.eval(), x2.eval(), x3.eval(), x4.eval())#print 1 2 1 1
print(tf.get_default_graph().get_tensor_by_name('scope1/x:0').eval())#1
print(tf.get_default_graph().get_tensor_by_name('scope1/scope2/x:0').eval())#2

3.初始化

TF中使用tf.initialize_all_variables()對所有的變量進行初始化。一般有兩種調用方法。
sess.run(tf.initialize_all_variables())
#tf.initialize_all_variables().run()
也可以單獨對某個變量進行初始化,但是不常見
sess.run(x1.initializer)



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章