Tensorflow 測試RL算法,保存模型 並 讀取進行測試

保存模型

RL中,我們一般都把一個網絡結構寫在一個類裏面,保存的時候也是,可以如下寫一個 save_net 函數:

def save_net(self):
    saver = tf.train.Saver()
    save_path = saver.save(self.sess, "./dqn/model/file_name.ckpt")
    print("Save to path: ", save_path)

在RL算法進行完N輪的訓練之後,調用該函數進行模型保存:agent.save_net()
可以看到,會在model文件夾下多出四個文件:
在這裏插入圖片描述
也可以輸出保存前的參數,進行觀察,以便確認讀取模型時是否成功讀取了參數:

w1 = tf.get_default_graph().get_tensor_by_name('eval_net/l1/w1:0')  # 獲得variable對應的Tensor
print(self.sess.run(w1))  # run一下這個Tensor得到結果

讀取模型

首先注意,讀取模型用於測試時,我們需要保證用到的變量和訓練時的是一樣的,比如測試DQN模型的效果:

class Test4DQN:
    def __init__(self):
        self.sess = tf.Session()
        self._build_net()

    def _build_net(self):
        # 測試時,只需要建立 evaluate_net,用來選擇動作
        self.s = tf.placeholder(tf.float32, [None, 11])

        with tf.variable_scope('eval_net'):
            with tf.variable_scope('l1'):
                w1 = tf.Variable(np.arange(110).reshape((11, 10)), dtype=tf.float32, name="w1")
                b1 = tf.Variable(np.arange(10).reshape((1, 10)), dtype=tf.float32, name="b1")
                l1 = tf.nn.relu(tf.matmul(self.s, w1) + b1)

            with tf.variable_scope('l2'):
                w2 = tf.Variable(np.arange(240).reshape((10, 24)), dtype=tf.float32, name="w2")
                b2 = tf.Variable(np.arange(24).reshape((1, 24)), dtype=tf.float32, name="b2")
                self.q_eval = tf.matmul(l1, w2) + b2

        # 讀取模型參數
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        self.sess.run(init)
        saver.restore(self.sess, "./xxxxx/model/file_name.ckpt")
        print(self.sess.run(w1))  # 可以再次輸出,和我們保存時的輸出結果進行對比,保證正確讀取
                    
                    
    def choose_action(self, observation):
        observation = observation[np.newaxis, :]
        actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})
        action = np.argmax(actions_value)
        return action

總結一下,就是先初始化一下測試框架中定義的變量(注意層級和名稱需要對應,原來叫’‘w1’‘現在也要叫’‘w1’’),然後調用saver.restore(self.sess, "./xxxxx/model/file_name.ckpt"),即可將保存的網絡參數賦值給現在的網絡。

之後,和原來RL的流程一樣,只是不再需要保存記憶和訓練而已,最後可以得到測試的效果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章