tf.name_scope('xxx') // tf.variable_scope('xxx'): weight sharing

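While coding over the past couple of days, I found that the deep-learning module I'm designing needs weight sharing. I have already written this up in my paper notes, so I'm recording it on the blog as well. The most common case of weight sharing is a GAN: both the generated data and the real data go through the discriminator during training, which means both paths share one set of variables. This is done with tf.variable_scope('scope_name').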

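with tf.variable_scope("discriminator"):
    # first call: creates the discriminator's variables under "discriminator/"
    predict_real, pred_real_dict = build_discriminator(input, target)
with tf.variable_scope("discriminator", reuse=True):
    # second call: reuse=True looks the same variables up instead of creating new ones
    predict_fake, pred_fake_dict = build_discriminator(input, transmission_layer)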
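The sharing above can be imitated with a tiny pure-Python model. This is not TensorFlow code. `VarStore`, its `get_variable` method and the one-weight `build_discriminator` are hypothetical stand-ins that only mimic the bookkeeping tf.get_variable() does inside tf.variable_scope. Variables are keyed by their full scoped name, created on the first call and looked up when reuse=True.

```python
class VarStore:
    """Toy stand-in for TensorFlow's variable store (hypothetical, not the TF API)."""

    def __init__(self):
        self.vars = {}  # full scoped name -> value holder

    def get_variable(self, scope, name, init, reuse):
        full_name = f"{scope}/{name}"
        if reuse:
            # reuse=True: the variable must already exist
            if full_name not in self.vars:
                raise ValueError(f"Variable {full_name} does not exist")
            return self.vars[full_name]
        # reuse=False: creating the same name twice is an error
        if full_name in self.vars:
            raise ValueError(f"Variable {full_name} already exists")
        self.vars[full_name] = [init]  # a list, so it can be updated in place
        return self.vars[full_name]


def build_discriminator(store, x, reuse):
    # hypothetical one-weight "discriminator": y = w * x + b
    w = store.get_variable("discriminator", "w", 0.5, reuse)
    b = store.get_variable("discriminator", "b", 0.1, reuse)
    return w[0] * x + b[0], {"w": w, "b": b}


store = VarStore()
predict_real, real_vars = build_discriminator(store, 2.0, reuse=False)  # creates variables
predict_fake, fake_vars = build_discriminator(store, 4.0, reuse=True)   # reuses them

print(sorted(store.vars))                # ['discriminator/b', 'discriminator/w']
print(real_vars["w"] is fake_vars["w"])  # True: one set of weights for both inputs
```

Updating `real_vars["w"][0]` would also change what the fake branch computes, which is the point of running real and generated data through one discriminator.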

1. tf.name_scope('xxx')

Purpose: mainly used together with tf.Variable() to make parameter names easier to organize.

with tf.name_scope('conv1') as scope:
    weights1 = tf.Variable([1.0, 2.0], name='weights')
    bias1 = tf.Variable([0.3], name='bias')

# The variables below are defined in a different name scope
with tf.name_scope('conv2') as scope:
    weights2 = tf.Variable([4.0, 2.0], name='weights')
    bias2 = tf.Variable([0.33], name='bias')

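This gives two name scopes, "conv1" and "conv2", so the variables are named conv1/weights, conv1/bias, conv2/weights and conv2/bias.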

However, if the code is changed to:

with tf.name_scope('conv1') as scope:
    weights1 = tf.Variable([1.0, 2.0], name='weights')
    bias1 = tf.Variable([0.3], name='bias')


with tf.name_scope('conv2') as scope:
    weights2 = tf.Variable([4.0, 2.0], name='weights')
    bias2 = tf.Variable([0.33], name='bias')
#----------------------------------------------------------
with tf.name_scope('conv1') as scope:
    weights1 = tf.Variable([1.0, 2.0], name='weights')
    bias1 = tf.Variable([0.3], name='bias')

with tf.name_scope('conv2') as scope:
    weights2 = tf.Variable([4.0, 2.0], name='weights')
    bias2 = tf.Variable([0.33], name='bias')

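Then there are two sets of name scopes:

conv1 and conv2; conv1_1 and conv2_1.

Entering a name scope whose name is already taken does not reopen it. TensorFlow appends a numeric suffix to keep the name unique.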
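The suffixing can be imitated with a small counter, loosely modeled on how a TensorFlow graph makes name scopes unique. `ToyGraph`, its `name_scope` and its `variable` methods are hypothetical helpers, not TensorFlow's implementation, and they handle only flat (non-nested) scopes. TensorFlow itself reports names such as conv1_1/weights:0.

```python
from contextlib import contextmanager


class ToyGraph:
    """Hypothetical model of name-scope uniquification (not the real tf.Graph)."""

    def __init__(self):
        self.used = {}      # scope name -> how many times it has been opened
        self.current = ""   # currently active scope (flat scopes only)
        self.variables = []

    @contextmanager
    def name_scope(self, name):
        count = self.used.get(name, 0)
        self.used[name] = count + 1
        # first use keeps the name, later uses get _1, _2, ...
        unique = name if count == 0 else f"{name}_{count}"
        previous, self.current = self.current, unique
        try:
            yield unique
        finally:
            self.current = previous

    def variable(self, name):
        # tf.Variable()-like: always records a new variable under the current scope
        full_name = f"{self.current}/{name}"
        self.variables.append(full_name)
        return full_name


g = ToyGraph()
for _ in range(2):  # the same two blocks, written twice
    with g.name_scope("conv1"):
        g.variable("weights")
        g.variable("bias")
    with g.name_scope("conv2"):
        g.variable("weights")
        g.variable("bias")

print(g.variables)
# ['conv1/weights', 'conv1/bias', 'conv2/weights', 'conv2/bias',
#  'conv1_1/weights', 'conv1_1/bias', 'conv2_1/weights', 'conv2_1/bias']
```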

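Note also that tf.Variable() always creates a new variable. If the second snippet used tf.get_variable() instead, it would raise an error, because tf.get_variable() checks whether a variable with that name already exists. In fact tf.get_variable() ignores tf.name_scope altogether, so its variable names only pick up prefixes from tf.variable_scope.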
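The contrast can be sketched with another toy model. The names are hypothetical and this is not TensorFlow code. `new_variable` behaves like tf.Variable() and simply picks a fresh name on a clash. `get_variable` behaves like tf.get_variable() without reuse: it keeps its own registry, similar to TF's variable store, and refuses to create a name that is already registered.

```python
class ToyVariables:
    """Hypothetical contrast of tf.Variable()-style vs tf.get_variable()-style creation."""

    def __init__(self):
        self.plain = set()  # names taken by new_variable()
        self.store = set()  # registry checked by get_variable(), like TF's variable store

    def new_variable(self, name):
        # tf.Variable()-like: never fails, appends _1, _2, ... on a name clash
        unique, i = name, 0
        while unique in self.plain:
            i += 1
            unique = f"{name}_{i}"
        self.plain.add(unique)
        return unique

    def get_variable(self, name):
        # tf.get_variable()-like with reuse=False: an existing name is an error
        if name in self.store:
            raise ValueError(f"Variable {name} already exists")
        self.store.add(name)
        return name


vs = ToyVariables()
print(vs.new_variable("weights"))  # weights
print(vs.new_variable("weights"))  # weights_1

vs.get_variable("bias")  # first creation succeeds
try:
    vs.get_variable("bias")  # same name again
except ValueError as err:
    print(err)  # Variable bias already exists
```

In real TensorFlow, opening the scope with tf.variable_scope(..., reuse=True) is what turns that error into a lookup of the existing variable.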

 

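2. tf.variable_scope('xxx')

Purpose: mainly used together with tf.get_variable() to share variables. Inside a scope opened with reuse=True, or after scope.reuse_variables(), tf.get_variable() returns the existing variable with that name instead of creating a new one. The shared module must therefore build its variables with tf.get_variable().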

with tf.variable_scope('rnn') as scope:
    sess = tf.Session()
    train_rnn = RNN(train_config)  # RNN is the network module whose weights we want to share

    scope.reuse_variables()  # this line is the key to sharing!

    test_rnn = RNN(test_config)
    sess.run(tf.global_variables_initializer())
# This pattern seems to be the more common one:
with tf.variable_scope("image_filters") as scope:
    result1 = my_image_filter(image1)
with tf.variable_scope("image_filters", reuse=True):  # passing scope instead of "image_filters" also works
    result2 = my_image_filter(image2)
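scope.reuse_variables() flips the scope's reuse flag for the rest of the `with` block, so the second RNN looks its weights up instead of creating them. The pure-Python model below sketches that behavior under stated assumptions. `ToyScope` and `make_rnn` are hypothetical illustrations, not TensorFlow APIs, and `make_rnn` stands in for the post's RNN(config).

```python
class ToyScope:
    """Hypothetical model of a variable scope with a reuse flag (not tf.variable_scope)."""

    def __init__(self, name, store, reuse=False):
        self.name = name
        self.store = store  # dict shared across scopes: full name -> value
        self.reuse = reuse

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False  # do not swallow exceptions

    def reuse_variables(self):
        # like scope.reuse_variables(): from now on, look variables up instead of creating
        self.reuse = True

    def get_variable(self, name, init):
        full_name = f"{self.name}/{name}"
        if self.reuse:
            if full_name not in self.store:
                raise ValueError(f"Variable {full_name} does not exist")
            return self.store[full_name]
        if full_name in self.store:
            raise ValueError(f"Variable {full_name} already exists")
        self.store[full_name] = init
        return init


def make_rnn(scope, config):
    # hypothetical stand-in for RNN(config): the config differs, the weights do not
    kernel = scope.get_variable("kernel", [0.0] * config["units"])
    return {"kernel": kernel, "batch": config["batch"]}


store = {}
with ToyScope("rnn", store) as scope:
    train_rnn = make_rnn(scope, {"units": 4, "batch": 32})
    scope.reuse_variables()  # the key line for sharing
    test_rnn = make_rnn(scope, {"units": 4, "batch": 1})

print(train_rnn["kernel"] is test_rnn["kernel"])  # True
print(list(store))                                # ['rnn/kernel']
```

Re-entering with ToyScope("rnn", store, reuse=True) reaches the same variables, which mirrors the second "image_filters" pattern.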
        
    

 
