tf.name_scope('xxx') // tf.variable_scope('xxx'): Weight Sharing

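While coding over the past couple of days, I found that a deep learning module I am designing needs weight sharing. I already wrote this up in my paper notes, so I am recording it on the blog as well. The most common example of weight sharing is a GAN. Both the generated data and the real data pass through the discriminator during training, so both passes share a single set of variables. This is done with tf.variable_scope('scope_name'):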

import tensorflow as tf  # TensorFlow 1.x API

with tf.variable_scope("discriminator"):
    predict_real, pred_real_dict = build_discriminator(input, target)  # first call creates the variables
with tf.variable_scope("discriminator", reuse=True):
    predict_fake, pred_fake_dict = build_discriminator(input, transmission_layer)  # second call reuses them
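To make the sharing concrete without TensorFlow, here is a pure-Python toy model. It assumes a dictionary-based variable store and a one-weight logistic "discriminator". `store`, `get_param` and `toy_discriminator` are hypothetical names, not TensorFlow functions. The first call creates the parameters and the second call only looks them up, so real and fake inputs are scored by the same weights.

```python
import math

# Hypothetical toy "variable store" standing in for TensorFlow's (not the real API):
# maps full names such as "discriminator/w" to plain Python floats.
store = {}

def get_param(scope, name, init, reuse):
    """Create scope/name on the first call; fetch the existing value when reuse=True."""
    full = f"{scope}/{name}"
    if reuse:
        return store[full]  # a KeyError here plays the role of TF's "does not exist" error
    if full in store:
        raise ValueError(f"Variable {full} already exists")
    store[full] = init
    return init

def toy_discriminator(x, reuse):
    # One-weight logistic classifier; w and b are the variables that get shared.
    w = get_param("discriminator", "w", 0.5, reuse)
    b = get_param("discriminator", "b", -0.1, reuse)
    return 1.0 / (1.0 + math.exp(-(w * x + b)))

predict_real = toy_discriminator(2.0, reuse=False)  # creates discriminator/w and discriminator/b
predict_fake = toy_discriminator(-1.0, reuse=True)  # reuses the same two values
print(sorted(store))  # ['discriminator/b', 'discriminator/w']
```

In real TF1 code, sharing only works if build_discriminator creates its weights with tf.get_variable() (tf.layers uses it internally). Weights created with tf.Variable() are not reused by reuse=True.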

1. tf.name_scope('xxx')

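Purpose: mainly used together with tf.Variable() to keep parameter names organized. Every op and every tf.Variable() created inside the scope gets the scope name as a prefix (e.g. conv1/weights). Note that tf.get_variable() ignores tf.name_scope.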

with tf.name_scope('conv1') as scope:
    weights1 = tf.Variable([1.0, 2.0], name='weights')
    bias1 = tf.Variable([0.3], name='bias')

# The following variables are defined in a different name scope
with tf.name_scope('conv2') as scope:
    weights2 = tf.Variable([4.0, 2.0], name='weights')
    bias2 = tf.Variable([0.33], name='bias')

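There are now two name scopes, "conv1" and "conv2". The four variables are named conv1/weights, conv1/bias, conv2/weights and conv2/bias.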

If the code is instead changed to run the same blocks twice:

with tf.name_scope('conv1') as scope:
    weights1 = tf.Variable([1.0, 2.0], name='weights')
    bias1 = tf.Variable([0.3], name='bias')


with tf.name_scope('conv2') as scope:
    weights2 = tf.Variable([4.0, 2.0], name='weights')
    bias2 = tf.Variable([0.33], name='bias')
# ---- enter the same name scopes a second time ----
with tf.name_scope('conv1') as scope:
    weights1 = tf.Variable([1.0, 2.0], name='weights')
    bias1 = tf.Variable([0.3], name='bias')

with tf.name_scope('conv2') as scope:
    weights2 = tf.Variable([4.0, 2.0], name='weights')
    bias2 = tf.Variable([0.33], name='bias')

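then there are four name scopes:

conv1 and conv2 from the first pass, and conv1_1 and conv2_1 from the second. Re-entering a name scope that is already in use does not reopen it. Instead, TensorFlow appends _1 (then _2, ...) to make the name unique. The tf.Variable() calls inside it simply create new variables such as conv1_1/weights.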
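The naming rule can be illustrated with a minimal pure-Python sketch. It assumes a simplified counter, and the helpers `unique`, `name_scope` and `variable` are hypothetical, not TensorFlow's. A scope or variable name that is already taken gets `_1`, `_2`, ... appended. Real TensorFlow additionally appends `:0` to variable names, e.g. `conv1_1/weights:0`.

```python
from contextlib import contextmanager

# Hypothetical toy model of TF1 graph-mode naming; not TensorFlow code.
names_in_use = set()
scope_stack = []

def unique(name):
    """Return name, or name_1, name_2, ... if that name is already taken."""
    candidate, i = name, 0
    while candidate in names_in_use:
        i += 1
        candidate = f"{name}_{i}"
    names_in_use.add(candidate)
    return candidate

def _prefix():
    # Full name of the innermost open scope, plus a separator.
    return scope_stack[-1] + "/" if scope_stack else ""

@contextmanager
def name_scope(name):
    # Like tf.name_scope: re-entering "conv1" opens a fresh scope "conv1_1".
    scope_stack.append(unique(_prefix() + name))
    try:
        yield
    finally:
        scope_stack.pop()

def variable(name):
    # Like tf.Variable: always creates a new variable under the current scope.
    return unique(_prefix() + name)

created = []
for _ in range(2):  # run the conv1/conv2 block twice, as in the snippet above
    with name_scope("conv1"):
        created += [variable("weights"), variable("bias")]
    with name_scope("conv2"):
        created += [variable("weights"), variable("bias")]
print(created)
```

The second pass yields conv1_1/weights, conv1_1/bias, conv2_1/weights and conv2_1/bias, which matches the four scopes listed above.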

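Incidentally, tf.Variable() creates a new variable every time it is called. If these snippets used tf.get_variable() instead, they would raise a ValueError as soon as a second variable named weights was requested. tf.get_variable() checks whether a variable with the given name already exists, and because it ignores tf.name_scope, every block would be asking for a variable simply called weights.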
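The existence check can be sketched in plain Python. This is a toy model under stated assumptions: `ToyVarStore` and its `get_variable` method are hypothetical stand-ins, and real TensorFlow's error messages are longer. With reuse off, a name may be created only once. With reuse on, the name must already exist, and the original object is returned.

```python
# Hypothetical toy model of the existence check in tf.get_variable (not TensorFlow).
class ToyVarStore:
    def __init__(self):
        self.vars = {}

    def get_variable(self, name, value=None, reuse=False):
        if reuse:
            # reuse=True: the variable must already exist.
            if name not in self.vars:
                raise ValueError(f"Variable {name} does not exist")
            return self.vars[name]
        # reuse=False: the variable must NOT exist yet.
        if name in self.vars:
            raise ValueError(f"Variable {name} already exists, disallowed")
        self.vars[name] = value
        return value

store = ToyVarStore()
# get_variable ignores tf.name_scope, so "conv1" and "conv2" both ask for plain "weights".
w1 = store.get_variable("weights", [1.0, 2.0])        # conv1: created
try:
    store.get_variable("weights", [4.0, 2.0])         # conv2: same name, no reuse
except ValueError as err:
    print(err)  # Variable weights already exists, disallowed
w_shared = store.get_variable("weights", reuse=True)  # explicit reuse returns the original
print(w_shared is w1)  # True
```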

 

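2. tf.variable_scope('xxx')

Purpose: used together with tf.get_variable(), so that variables created in a scope can be looked up again and shared by turning on reuse.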

with tf.variable_scope('rnn') as scope:
    sess = tf.Session()
    train_rnn = RNN(train_config)  # RNN is the network module we want to share

    scope.reuse_variables()  # This line is the key to sharing!

    test_rnn = RNN(test_config)
    sess.run(tf.global_variables_initializer())
# This form seems to be the more common one
with tf.variable_scope("image_filters") as scope:
    result1 = my_image_filter(image1)
with tf.variable_scope("image_filters", reuse=True):  # passing scope instead of "image_filters" also works
    result2 = my_image_filter(image2)
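Here is a pure-Python sketch of what scope.reuse_variables() changes. It assumes two toy classes, `ToyScope` and `ToyRNN`, which are hypothetical and not TensorFlow. The call switches the scope into reuse mode, so the second model fetches the weight the first one created instead of making a new one. Reopening a scope with `reuse=True`, as in the image_filters snippet, has the same effect.

```python
class ToyScope:
    """Hypothetical stand-in for a TF1 VariableScope; not TensorFlow code."""

    def __init__(self, name):
        self.name = name
        self.reuse = False
        self.vars = {}

    def reuse_variables(self):
        # Mirrors scope.reuse_variables(): from now on, lookups replace creation.
        self.reuse = True

    def get_variable(self, name, value):
        full = f"{self.name}/{name}"
        if self.reuse:
            if full not in self.vars:
                raise ValueError(f"Variable {full} does not exist")
            return self.vars[full]
        if full in self.vars:
            raise ValueError(f"Variable {full} already exists")
        self.vars[full] = value
        return value


class ToyRNN:
    """Hypothetical model with a single weight vector sized by config["hidden"]."""

    def __init__(self, scope, config):
        self.config = config
        self.w = scope.get_variable("w", [0.0] * config["hidden"])


scope = ToyScope("rnn")
train_rnn = ToyRNN(scope, {"hidden": 4, "batch": 32})  # creates rnn/w
scope.reuse_variables()                                 # the key line
test_rnn = ToyRNN(scope, {"hidden": 4, "batch": 1})    # fetches rnn/w
train_rnn.w[0] = 1.5   # stand-in for a training update
print(test_rnn.w[0])   # 1.5, because both models hold the same list
```

The train and test configs may differ in things like batch size. In real TensorFlow, however, a shared variable must be requested with the same shape, or tf.get_variable() raises a ValueError.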