tf.get_variable()和tf.Variable()的区别

Variable

最近在学习TensorFlow的过程中, 看到在定义变量的时候有两种操作:
tf.get_variable()和tf.Variable()。

def weight_variable(shape):
    #initial = tf.truncated_normal(shape, stddev=0.1)
    #return tf.Variable(initial) # tf.get_variable()
    return tf.get_variable(name="w", shape=shape,
                           initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                       stddev=0.1,
                                                                       seed=None,
                                                                       dtype=tf.float32))


with tf.Session() as sess:
    w1 = weight_variable([3, 3, 3, 1, 1])
    sess.run(tf.initialize_all_variables())
    print(sess.run(w1))

本以为两者没什么区别,但是博士师兄建议使用tf.get_variable()定义,不解查阅,于是总结了一下两者区别,如下:

tf.Variable()

tf.Variable(initial_value=None, 
            trainable=True, 
            collections=None, 
            validate_shape=True, 
            caching_device=None, 
            name=None, variable_def=None, 
            dtype=None, expected_shape=None, 
            import_scope=None)

tf.get_variable()

tf.get_variable(name, shape=None, 
                dtype=None, initializer=None, 
                regularizer=None, trainable=True, 
                collections=None, caching_device=None,
                partitioner=None, validate_shape=True,
                custom_getter=None)

先看一段代码:

import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print w_1.name
print w_2.name
#输出
#w_1:0
#w_1_1:0
import tensorflow as tf
w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
#错误信息
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?

区别:

  • 使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错。
  • tf.Variable()每次都在创建新的对象,与name没有关系。而tf.get_variable()对于已经创建的同样name的变量对象,就直接把那个变量对象返回(类似于:共享变量),tf.get_variable() 会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。
  • tf.get_variable():对于在上下文管理器中已经生成一个v的变量,若想通过tf.get_variable函数获取其变量,则可以通过reuse参数的设定为True来获取。
  • 还有一点,tf.get_variable()必须写name,否则报错(but instead was %s." % (name, shape))
    ValueError: Shape of a new variable (Tensor("truncated_normal:0", shape=(2, 3), dtype=float32)) must be fully defined, but instead was <unknown>.
    ),tf.Variable()不要求。
#需要注意的是tf.get_variable() 要配合reuse和tf.variable_scope() 使用
with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #创建两个名字一样的变量会报错 ValueError: Variable one/v already exists 
with tf.variable_scope("one", reuse = True): #注意reuse的作用。
    c = tf.get_variable("v", [1]) #c.name == "one/v:0" 成功共享,因为设置了reuse

assert a==c #Assertion is true, they refer to the same object.

对于tf.Variable():

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"  

assert d==e #AssertionError: they are different objects

【tensorflow 学习】tf.get_variable()和tf.Variable()的区别

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章