Variable
最近在学习TensorFlow的过程中, 看到在定义变量的时候有两种操作:
tf.get_variable()和tf.Variable()。
def weight_variable(shape):
    """Create (or, under a reusing variable_scope, fetch) a weight variable.

    Uses tf.get_variable so the variable can be shared through variable
    scopes, instead of tf.Variable, which always creates a new variable.

    Args:
        shape: shape of the weight tensor to create.

    Returns:
        A tf.Variable named "w" initialized from a truncated normal
        distribution (mean 0.0, stddev 0.1).
    """
    # Equivalent tf.Variable version (always creates a new, unshared variable):
    # initial = tf.truncated_normal(shape, stddev=0.1)
    # return tf.Variable(initial)
    return tf.get_variable(name="w", shape=shape,
                           initializer=tf.truncated_normal_initializer(
                               mean=0.0,
                               stddev=0.1,
                               seed=None,
                               dtype=tf.float32))
# Build the variable, initialize it, and print its initial value.
with tf.Session() as sess:
    w1 = weight_variable([3, 3, 3, 1, 1])
    # tf.initialize_all_variables() is deprecated;
    # tf.global_variables_initializer() is the supported equivalent.
    sess.run(tf.global_variables_initializer())
    print(sess.run(w1))
本以为两者没什么区别,但是博士师兄建议使用tf.get_variable()定义,不解查阅,于是总结了一下两者区别,如下:
tf.Variable()
tf.Variable(initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None, variable_def=None,
dtype=None, expected_shape=None,
import_scope=None)
tf.get_variable()
tf.get_variable(name, shape=None,
dtype=None, initializer=None,
regularizer=None, trainable=True,
collections=None, caching_device=None,
partitioner=None, validate_shape=True,
custom_getter=None)
先看一段代码:
import tensorflow as tf

# Two tf.Variable objects created with the SAME name argument: TensorFlow
# resolves the clash itself by uniquifying the second name.
w_1 = tf.Variable(3, name="w_1")
w_2 = tf.Variable(1, name="w_1")
# print is a function in Python 3 (consistent with print(...) used above).
print(w_1.name)
print(w_2.name)
# Output:
# w_1:0
# w_1_1:0
import tensorflow as tf
w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
# Error raised by the second call (get_variable refuses a duplicate name):
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?
区别:
- 使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错。
- tf.Variable()每次都在创建新的对象,与name没有关系。而tf.get_variable()对于已经创建的同样name的变量对象,就直接把那个变量对象返回(类似于:共享变量),tf.get_variable() 会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。
- tf.get_variable():对于在上下文管理器中已经生成一个v的变量,若想通过tf.get_variable函数获取其变量,则可以通过reuse参数的设定为True来获取。
- 还有一点，tf.get_variable()必须指定name参数，否则会报错；而tf.Variable()不要求。例如，未能完全确定shape时会报：
ValueError: Shape of a new variable (Tensor("truncated_normal:0", shape=(2, 3), dtype=float32)) must be fully defined, but instead was <unknown>.
# NOTE: tf.get_variable() is meant to be used together with reuse and
# tf.variable_scope().
with tf.variable_scope("one"):
    a = tf.get_variable("v", [1])  # a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1])  # creating a second variable with the same
                                   # name raises:
                                   # ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse=True):  # note the effect of reuse
    c = tf.get_variable("v", [1])  # c.name == "one/v:0" — shared, thanks to reuse
assert a == c  # Assertion is true: they refer to the same object.
对于tf.Variable():
with tf.variable_scope("two"):
    d = tf.get_variable("v", [1])  # d.name == "two/v:0"
    # tf.Variable ignores the scope's name bookkeeping for clashes and simply
    # uniquifies: e.name == "two/v_1:0"
    e = tf.Variable(1, name="v", expected_shape=[1])
assert d == e  # AssertionError: they are different objects.