保存模型
RL中,我們一般都把一個網絡結構寫在一個類裏面,保存的時候也是,可以如下寫一個 save_net 函數:
def save_net(self):
saver = tf.train.Saver()
save_path = saver.save(self.sess, "./dqn/model/file_name.ckpt")
print("Save to path: ", save_path)
在RL算法進行完N輪的訓練之後,調用該函數進行模型保存:agent.save_net()
可以看到,會在model文件夾下多出四個文件:
也可以輸出保存前的參數,進行觀察,以便確認讀取模型時是否成功讀取了參數:
w1 = tf.get_default_graph().get_tensor_by_name('eval_net/l1/w1:0') # 獲得variable對應的Tensor
print(self.sess.run(w1)) # run一下這個Tensor得到結果
讀取模型
首先注意,讀取模型用於測試時,我們需要保證用到的變量和訓練時的是一樣的,比如測試DQN模型的效果:
class Test4DQN:
def __init__(self):
self.sess = tf.Session()
self._build_net()
def _build_net(self):
# 測試時,只需要建立 evaluate_net,用來選擇動作
self.s = tf.placeholder(tf.float32, [None, 11])
with tf.variable_scope('eval_net'):
with tf.variable_scope('l1'):
w1 = tf.Variable(np.arange(110).reshape((11, 10)), dtype=tf.float32, name="w1")
b1 = tf.Variable(np.arange(10).reshape((1, 10)), dtype=tf.float32, name="b1")
l1 = tf.nn.relu(tf.matmul(self.s, w1) + b1)
with tf.variable_scope('l2'):
w2 = tf.Variable(np.arange(240).reshape((10, 24)), dtype=tf.float32, name="w2")
b2 = tf.Variable(np.arange(24).reshape((1, 24)), dtype=tf.float32, name="b2")
self.q_eval = tf.matmul(l1, w2) + b2
# 讀取模型參數
saver = tf.train.Saver()
init = tf.global_variables_initializer()
self.sess.run(init)
saver.restore(self.sess, "./xxxxx/model/file_name.ckpt")
print(self.sess.run(w1)) # 可以再次輸出,和我們保存時的輸出結果進行對比,保證正確讀取
def choose_action(self, observation):
observation = observation[np.newaxis, :]
actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})
action = np.argmax(actions_value)
return action
總結一下,就是先初始化一下測試框架中定義的變量(注意層級和名稱需要對應,原來叫’‘w1’‘現在也要叫’‘w1’’),然後調用saver.restore(self.sess, "./xxxxx/model/file_name.ckpt")
,即可將保存的網絡參數賦值給現在的網絡。
之後,和原來RL的流程一樣,只是不再需要保存記憶和訓練而已,最後可以得到測試的效果。