The baselines project has a flexible design and a rather complex structure. The project is large, its functions call into one another, and tracing a single call chain can take you six or seven levels deep; each function also carries a great many hyperparameters, which makes it easy to get lost.
From here on we look only at the DQN part of the source code, set everything unrelated aside, and decompose it along a single path. Treat it as a recursion game: descend level by level, explore until you hit bottom, then return, pruning along the way; anything unrelated to the network graph will not be visited.
First, find the recursion entry point. Under deepq there is an experiments directory full of examples; pong is an experiment on an Atari game.
Below is the code of train_pong:
1.
from baselines import deepq
from baselines import bench
from baselines import logger
from baselines.common.atari_wrappers import make_atari
import numpy as np

np.seterr(invalid='ignore')

def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )
    model.save('pong_model.pkl')
    env.close()

if __name__ == '__main__':
    main()
The code above mainly calls into the deepq module. Take a close look at deepq.learn(): it has a lot of parameters. For now we only care about the network structure, and the convolutional layers (convs) and fully connected layers (hiddens) clearly belong to it. Ctrl-click on learn to jump to its definition. Below are learn()'s hyperparameters. Don't be scared: it looks like a lot, but it is not as complex as you imagine, because its complexity is beyond what you can imagine.
2.
def learn(env,
          network,
          seed=None,
          lr=5e-4,
          total_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=100,
          checkpoint_freq=10000,
          checkpoint_path=None,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          param_noise=False,
          callback=None,
          load_path=None,
          **network_kwargs
          ):
But parameters such as convs do not appear here, which means they must live inside **network_kwargs (a dict). Careful comparison shows that **network_kwargs holds exactly three entries: {convs, hiddens, dueling}.
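The routing of **network_kwargs can be illustrated with a tiny standalone sketch. The names build_sketch and learn_sketch are made up for illustration and only mimic the signatures of build_q_func and learn: any keyword argument not named in a function's signature is collected into the kwargs dict and can be forwarded wholesale to the next call.

```python
def build_sketch(network, hiddens=[256], dueling=True, **rest):
    # hiddens and dueling are consumed here; anything left over stays in `rest`
    return {"network": network, "hiddens": hiddens, "dueling": dueling, "rest": rest}

def learn_sketch(env, network, lr=5e-4, **network_kwargs):
    # lr is a named parameter of learn_sketch; convs/hiddens/dueling are not,
    # so they are all collected into network_kwargs and forwarded wholesale
    return build_sketch(network, **network_kwargs)

result = learn_sketch("env", "conv_only",
                      convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
                      hiddens=[256], dueling=True, lr=1e-4)
print(result["rest"])  # {'convs': [(32, 8, 4), (64, 4, 2), (64, 3, 1)]}
```

Exactly as in the real code, hiddens and dueling get consumed one level down, and convs is the only entry left to travel further.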
Next, let's see which function actually uses **network_kwargs:
3.
q_func = build_q_func(network, **network_kwargs)
Only this one line uses it, nice! Click in, and a whole new world opens up.
4.
def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
        # print('network:', network)
        network = get_network_builder(network)(**network_kwargs)
        # print('network:', network)

    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores

            return q_out

    return q_func_builder
We are getting close; this is clearly where the network structure gets built. Two of the **network_kwargs entries are consumed here, leaving only 'convs' unused. From snippet 3 we know this function returns a function, namely the q_func_builder defined here. But before reading q_func_builder, we need to descend one more level; first look at this line near the top:
5.
network = get_network_builder(network)(**network_kwargs)
From the earlier call we know network = "conv_only", a string. Starting from this string, a function is returned, and **network_kwargs (that is, the convolutional-layer parameters) is then passed to that returned function, which presumably ends up producing a convolutional network structure. Click in and take a look:
6.
def get_network_builder(name):
    if callable(name):  # callable() checks whether an object can be called
        # print('name', name)
        return name
    elif name in mapping:
        # print('mapping', mapping)
        return mapping[name]
    else:
        raise ValueError('Unknown network type: {}'.format(name))
This code is concise: first check whether name is already callable; otherwise look it up in the mapping dict and return the corresponding function. Our "conv_only" is clearly fetched from mapping.
Let's print mapping and see what is in it:
7.
print(mapping)
output:
mapping {'mlp': <function mlp at 0x000001C90316B5E8>, 'cnn': <function cnn at 0x000001C90316B678>, 'impala_cnn': <function impala_cnn at 0x000001C90316B708>, 'cnn_small': <function cnn_small at 0x000001C90316B798>, 'lstm': <function lstm at 0x000001C90316B828>, 'cnn_lstm': <function cnn_lstm at 0x000001C90316B8B8>, 'impala_cnn_lstm': <function impala_cnn_lstm at 0x000001C90316B948>, 'cnn_lnlstm': <function cnn_lnlstm at 0x000001C90316B9D8>, 'conv_only': <function conv_only at 0x000001C90316BA68>}
network: <function conv_only.<locals>.network_fn at 0x000001C9030AF678>
So mapping holds quite a few functions. When were they added to mapping? Searching this file for mapping turns up the register() function:
8.
mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk
It simply stores a name together with its function in mapping.
Searching for register then turns up many @register decorators:
9.
@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):...
@register("cnn")
def cnn(**conv_kwargs):...
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):....
Putting @register("xxx") above a function calls the register() function from snippet 8 and inserts that function straight into mapping, which matches exactly the contents of mapping we printed. Decorators are a fine thing.
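The whole mechanism fits in a few lines. Here is a self-contained sketch of the same registry pattern, independent of baselines (the two registered functions are dummies):

```python
mapping = {}

def register(name):
    # register("xxx") returns a decorator that files the function under name
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk

# @register("conv_only") is equivalent to: conv_only = register("conv_only")(conv_only)
@register("conv_only")
def conv_only():
    return "conv network"

@register("mlp")
def mlp():
    return "mlp network"

print(sorted(mapping))         # ['conv_only', 'mlp']
print(mapping["conv_only"]())  # conv network
```

Because _thunk returns func unchanged, the decorated function still works normally; registration in mapping is purely a side effect of the import.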
By now I have lost count of how many levels we have jumped through. As long as the route stays clear in your head, you won't get confused. Let's focus directly on the conv_only function:
10.
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
    def network_fn(X):
        out = tf.cast(X, tf.float32) / 255.
        with tf.variable_scope("convnet"):
            for num_outputs, kernel_size, stride in convs:
                out = tf.contrib.layers.convolution2d(out,
                                                      num_outputs=num_outputs,
                                                      kernel_size=kernel_size,
                                                      stride=stride,
                                                      activation_fn=tf.nn.relu,
                                                      **conv_kwargs)
        return out
    return network_fn
Look at the overall structure first. Just as we guessed at snippet 5, the **network_kwargs passed in earlier has become this function's convs. It returns a network_fn function, which is then given the image input, the X parameter here. Now examine the function closely:
- convs is a list of tuples. X is the input image, a Tensor of shape [batch, 84, 84, 4].
- The input is first normalized to floats in [0, 1] by dividing by 255, then the elements of convs are iterated over to build the convolutional network.
- num_outputs is the number of output channels, i.e. the number of convolution kernels; kernel_size is the kernel size; stride is the step size. These are all standard convolutional-network parameters.
- The output out is a Tensor of shape [batch, size, size, 64].
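The spatial sizes can be checked with plain arithmetic. Assuming the default 'SAME' padding of tf.contrib.layers.convolution2d (under which each layer's output size is ceil(input_size / stride), independent of the kernel size), a quick sketch:

```python
import math

def same_pad_out(size, stride):
    # with 'SAME' padding the output spatial size depends only on the stride
    return math.ceil(size / stride)

size = 84  # input frames are [batch, 84, 84, 4]
for num_outputs, kernel_size, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
    size = same_pad_out(size, stride)
    print(size)  # 21, then 11, then 11

# so under this assumption the final feature map is [batch, 11, 11, 64]
```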
Now let's return to snippet 4, build_q_func, and read the rest of it.
From the analysis above, after the line
network = get_network_builder(network)(**network_kwargs)
executes, network is equivalent to the network_fn(X) function from snippet 10.
Now analyze q_func_builder():
- latent = network(input_placeholder) first feeds the input through network, returning latent, a [batch, size, size, 64] Tensor.
- After the tuple check and error message, latent = layers.flatten(latent) stretches the four-dimensional convolutional output into two dimensions in preparation for the fully connected layers.
- action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None) loops over hiddens to build the fully connected layers.
- action_out = layers.layer_norm(action_out, center=True, scale=True) normalizes the fully connected layer; this operation comes from the layer normalization paper.
- action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None) then attaches the output layer.
- What follows is the dueling optimization. Without going into the theory: a dueling network adds a second stream to the original network, whose output is a single node representing the value of the current state. The advantage stream has its mean subtracted from it, and the two outputs are then merged into a single q_out.
- Finally, the whole function is returned.
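The dueling combination at the bottom of q_func_builder is just Q(s, a) = V(s) + (A(s, a) - mean over a of A(s, a)); here is a minimal NumPy sketch with made-up numbers:

```python
import numpy as np

action_scores = np.array([[1.0, 2.0, 6.0]])  # advantage stream: one state, 3 actions
state_score = np.array([[5.0]])              # value stream: a single node V(s)

action_scores_mean = np.mean(action_scores, axis=1, keepdims=True)  # 3.0
action_scores_centered = action_scores - action_scores_mean         # [-2, -1, 3]
q_out = state_score + action_scores_centered
print(q_out)  # [[3. 4. 8.]]
```

Subtracting the mean forces the advantage stream to carry only the relative differences between actions, while the single value node carries the overall level.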
With that, the construction of the network has been fully analyzed. Now go back to snippet 3:
q_func = build_q_func(network, **network_kwargs)
This single call is enough to obtain a function that builds the network graph. q_func is only a function; no arguments have been passed to it yet, and no graph has been constructed.
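That "q_func is only a function" point can be made concrete with a closure sketch (all names here are hypothetical, for illustration only): the builder returns a callable immediately, and nothing resembling a network exists until that callable is invoked.

```python
built = []  # records every construction, so we can see when it happens

def build_q_func_sketch(network):
    def q_func_builder(num_actions):
        # in the real code, graph construction happens here, only at call time
        built.append((network, num_actions))
        return f"q-network({network}, {num_actions} actions)"
    return q_func_builder

q_func = build_q_func_sketch("conv_only")
print(built)     # [] -- q_func exists, but nothing has been built yet
out = q_func(6)  # construction happens now
print(built)     # [('conv_only', 6)]
```

This deferred-construction style is why later stages of learn() can hand q_func around freely before any placeholders or variables are created.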