tensorflow 多Agent 靈活保存、更新Graph的各部分參數(tf.variable_scope(), tf.get_collection(), tf.train.Saver())

當使用tensorflow搭建機器學習模型時,簡單的模型可以直接從輸入X開始,一層層地設置變量和operation,最終得到輸出Y^\hat{Y},並和label:Y一起計算出Loss Function,然後調用優化器最小化Loss即可。

然而,複雜的模型往往涉及到這樣的問題,有時候,我們並不希望像上面一樣,直接對模型進行端到端的更新,又或者模型涉及到多個agent,有多個優化目標,這時候,我們就不能像上面一樣簡單的直接對Graph內的全體參數直接進行梯度下降更新,而是需要靈活控制各部分參數。

主要依賴的函數是tf.variable_scope() 和 tf.get_collection()。

tf.variable_scope()

import tensorflow as tf
n1 = 50
input_dim = 100
output_dim = 1
X = tf.placeholder(tf.float32, [None, input_dim], 'X')
Y1 = tf.placeholder(tf.float32, [None, output_dim], 'Y1')
Y2= tf.placeholder(tf.float32, [None, output_dim], 'Y2')

with tf.variable_scope('Agent'):
    with tf.variable_scope('layer1'):
        w1 = tf.get_variable('w1', [input_dim, n1], trainable=True)
        b1 = tf.get_variable('b1', [1, n1], trainable=True)
        s = tf.matmul(X, w1) + b1
    with tf.variable_scope('layer2_1'):
        w21 = tf.get_variable('w21', [n1, output_dim], trainable=True)
        b21 = tf.get_variable('b21', [1, output_dim], trainable=True)
        y1 = tf.matmul(s, w21) + b21
    with tf.variable_scope('layer2_2'):
        w22 = tf.get_variable('w22', [n1, output_dim], trainable=True)
        b22 = tf.get_variable('b22', [1, output_dim], trainable=True)
        y2 = tf.matmul(s, w22) + b22

如上所示,我們定義了一個圖:其中所有的變量都在’Agent’內,‘Agent/layer1’內爲第一層變量,其輸出s被兩個互相獨立的層’Agent/layer2_1’和‘Agent/layer2_2’共享。

tf.get_collection()

有了上面的定義,我們就可以使用tf.get_collection()把相應scope內的變量取出來,輸出爲這些變量構成的列表:

agent_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent')
layer1_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent/layer1')
layer21_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent/layer2_1')
layer22_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent/layer2_2')

操作:參數訓練與保存讀取

有了上面的準備工作,我們就可以進行一些靈活的操作:如自主的控制其中的部分參數進行訓練,或者對其中部分訓練好的參數進行保存和讀取。

如,通過 var_list 選項指定需要被更新的參數範圍:

loss1 = tf.losses.mean_squared_error(labels=Y1, predictions=y1)
loss2 = tf.losses.mean_squared_error(labels=Y2, predictions=y2)

train1 = tf.train.AdamOptimizer(lr).minimize(loss1, var_list=layer1_params+layer21_params)
train2 = tf.train.AdamOptimizer(lr).minimize(loss2, var_list=layer1_params+layer22_params)

或者,使用tf.train.Saver()對相應部分參數進行保存和讀取:

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver1 = tf.train.Saver(agent_params)
saver2 = tf.train.Saver(layer1_params)
saver3 = tf.train.Saver(layer21_params)
saver4 = tf.train.Saver(layer22_params)
save_path1 = saver1.save(sess,"./models/agent.ckpt")
save_path2 = saver2.save(sess,"./models/layer1.ckpt")
save_path3 = saver3.save(sess,"./models/layer21.ckpt")
save_path4 = saver4.save(sess,"./models/layer22.ckpt")
saver1.restore(sess,"./models/agent.ckpt")
saver2.restore(sess,"./models/layer1.ckpt")
saver3.restore(sess,"./models/layer21.ckpt")
saver4.restore(sess,"./models/layer22.ckpt")
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章