Reinforcement Learning --- Dissecting the Network Structure for Atari Games in the baselines Project

This baselines project is designed to be quite flexible, and its structure is a bit complex. Since the project is large and its functions keep calling one another, tracing down from a single function can take you 6 or 7 levels deep, and every function has a great many hyperparameters, so it is easy to get dizzy.
Next we will look only at the DQN part of the source code, ignore everything unrelated, and take it apart along a single thread! Think of it as a game of recursion: go one level deeper at a time, explore until you reach the end, then return, pruning a little along the way of course; parts unrelated to the network graph will not be visited!
First, let's find the entry point of the recursion. Under deepq there is an experiments directory, which is full of examples; pong is one Atari game experiment.
Below is the code of train_pong:


1.

from baselines import deepq
from baselines import bench
from baselines import logger
from baselines.common.atari_wrappers import make_atari
import numpy as np
np.seterr(invalid='ignore')

def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )

    model.save('pong_model.pkl')
    env.close()

if __name__ == '__main__':
    main()

The code above mainly calls the deepq module. Look carefully at the deepq.learn() method: it takes a lot of parameters. For now we only care about the network structure, and clearly the convolutional layers (convs) and the fully connected layers (hiddens) belong to that part. Ctrl+click on learn to jump to where the learn() method is defined. Its hyperparameters are shown below. Don't be afraid: it looks like a lot, but it is not as complex as you imagine, because its complexity is beyond what you can imagine.


2.

def learn(env,
          network,
          seed=None,
          lr=5e-4,
          total_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=100,
          checkpoint_freq=10000,
          checkpoint_path=None,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          param_noise=False,
          callback=None,
          load_path=None,
          **network_kwargs
          ):

But parameters such as convs do not appear here, which means those parameters live inside **network_kwargs (a dict). Careful comparison shows:

**network_kwargs contains these 3 keys: {convs, hiddens, dueling}.
Next, let's see which function actually uses **network_kwargs.
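
To see how those extra keyword arguments end up inside **network_kwargs, here is a minimal sketch; learn_demo is a hypothetical stand-in for deepq.learn, not the real function:

def learn_demo(env, network, lr=5e-4, **network_kwargs):
    # Keywords not named in the signature are collected into network_kwargs.
    print(network_kwargs)

learn_demo(env=None, network="conv_only", lr=1e-4,
           convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
           hiddens=[256], dueling=True)
# {'convs': [(32, 8, 4), (64, 4, 2), (64, 3, 1)], 'hiddens': [256], 'dueling': True}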


3.

q_func = build_q_func(network, **network_kwargs)

Only this one line uses it. Nice! Click in, and a new world opens up.


4.

def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
       # print('network:',network)
        network = get_network_builder(network)(**network_kwargs)
       # print('network:', network)
    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores
            return q_out

    return q_func_builder

Very close now; this is clearly where the network structure gets built. Two of the parameters in **network_kwargs are used here, leaving only 'convs' unused. From 3 we know it returns a function, namely the q_func_builder defined here. Before reading that function, though, we have to descend one more level; first look at this line at the top:


5.

network = get_network_builder(network)(**network_kwargs)

From the argument passed earlier we know network = "conv_only", a string. From a string we get back a function, and then **network_kwargs (that is, the convolution-layer parameters) is passed on to that returned function. Presumably the final result is a convolutional network structure. Let's click in and have a look first:


6.

def get_network_builder(name):

    if callable(name):  # callable() checks whether an object can be called like a function
        # print('name',name)
        return name
    elif name in mapping:
       # print('mapping',mapping)
        return mapping[name]

    else:
        raise ValueError('Unknown network type: {}'.format(name))

This code is very concise: first check whether name is itself callable, then check whether it is in the mapping dict, and return the function accordingly. Clearly our "conv_only" is fetched from mapping.
Let's print mapping and take a look:


7.

print(mapping)
output:
mapping {'mlp': <function mlp at 0x000001C90316B5E8>, 'cnn': <function cnn at 0x000001C90316B678>, 'impala_cnn': <function impala_cnn at 0x000001C90316B708>, 'cnn_small': <function cnn_small at 0x000001C90316B798>, 'lstm': <function lstm at 0x000001C90316B828>, 'cnn_lstm': <function cnn_lstm at 0x000001C90316B8B8>, 'impala_cnn_lstm': <function impala_cnn_lstm at 0x000001C90316B948>, 'cnn_lnlstm': <function cnn_lnlstm at 0x000001C90316B9D8>, 'conv_only': <function conv_only at 0x000001C90316BA68>}
network: <function conv_only.<locals>.network_fn at 0x000001C9030AF678>

So there are quite a few functions in there. When were they added to mapping?
Searching this file for mapping turns up the register() function:


8.

mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk

It simply stores name and the function into mapping.
Searching for register again, we find many @register decorators:


9.

@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):...

@register("cnn")
def cnn(**conv_kwargs):...

@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):....

In fact, once @register("xxx") is attached, the register function from 8 above runs and adds the decorated function to mapping, which lines up exactly with the functions we saw printed from mapping. Decorators are a fine thing. Here is a tiny standalone sketch of this registry pattern (toy names, not baselines code):
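
mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func   # store the function under its registered name
        return func            # hand the function back unchanged
    return _thunk

@register("greet")             # equivalent to: greet = register("greet")(greet)
def greet():
    return "hello"

print(mapping)                 # {'greet': <function greet at 0x...>}
print(mapping["greet"]())      # hello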
By this point I have lost count of how many levels we have jumped. As long as the route in your head stays clear, you won't get confused. So let's focus directly on the conv_only function.


10.

@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
    def network_fn(X):
        out = tf.cast(X, tf.float32) / 255.
        with tf.variable_scope("convnet"):
            for num_outputs, kernel_size, stride in convs:
                out = tf.contrib.layers.convolution2d(out,
                                           num_outputs=num_outputs,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           activation_fn=tf.nn.relu,
                                           **conv_kwargs)

        return out
    return network_fn

First look at the overall structure. Sure enough, just as we said in 5, the **network_kwargs passed earlier has become this function's convs. It returns a network_fn function, which is then given the image input, namely the X parameter here. Now let's examine this function carefully!

  • convs is a list of tuples. X is the input image, a Tensor of shape [batch, 84, 84, 4].
  • X is first cast to float and scaled into the range 0-1, then the elements of convs are traversed to stack up the convolutional layers.
  • num_outputs is the number of channels, i.e. the number of convolution kernels; kernel_size is the kernel size; stride is the step size. These are the usual convolutional-network parameters.
  • The output out is a Tensor of shape [batch, size, size, 64] (a quick sanity check of the spatial sizes follows this list).
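
To pin down that spatial size, here is a small sanity check of my own (not baselines code). If I read tf.contrib.layers.convolution2d correctly, its default padding is 'SAME', where the output size is ceil(input / stride):

import math

size = 84   # Atari frames are preprocessed to 84x84
for num_outputs, kernel_size, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
    size = math.ceil(size / stride)   # 'SAME' padding: out = ceil(in / stride)
    print('conv %dx%d stride %d -> %dx%dx%d'
          % (kernel_size, kernel_size, stride, size, size, num_outputs))
# conv 8x8 stride 4 -> 21x21x32
# conv 4x4 stride 2 -> 11x11x64
# conv 3x3 stride 1 -> 11x11x64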


Now let's go back to 4 and look at it again:
def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
       # print('network:',network)
        network = get_network_builder(network)(**network_kwargs)
       # print('network:', network)
    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores
            return q_out

    return q_func_builder

Based on the analysis above, after the statement network = get_network_builder(network)(**network_kwargs) executes, network is effectively the network_fn(X) function from 10.

Next, let's analyze q_func_builder():

  • latent = network(input_placeholder): the input is first handed to network, which returns latent, a Tensor of shape [batch, size, size, 64].
  • After the type check and the error for recurrent policies, latent = layers.flatten(latent) flattens the 4-D convolutional output down to 2-D, ready for the fully connected layers.
  • action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None) builds the fully connected layers by looping over hiddens.
  • action_out = layers.layer_norm(action_out, center=True, scale=True) normalizes each fully connected layer; this operation is layer normalization, which comes from the Layer Normalization paper.
  • action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None) then attaches the output layer.
  • The part below is the dueling optimization. I won't go into the theory here; in short, a dueling network adds a second branch to the original network whose output is a single node representing the value of the current state. The action scores are centered by subtracting their own mean, and the centered scores are added to that state value to merge the two branches into one q_out (a tiny numeric example follows this list).
  • Finally the whole function is returned.
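
A tiny numeric illustration of that dueling aggregation, with made-up example values and numpy in place of TensorFlow:

import numpy as np

# q_out = state_score + (action_scores - mean(action_scores))
action_scores = np.array([[1.0, 2.0, 3.0]])   # advantage branch, 3 actions
state_score = np.array([[5.0]])               # value branch, a single node
action_scores_mean = action_scores.mean(axis=1, keepdims=True)   # [[2.0]]
q_out = state_score + (action_scores - action_scores_mean)
print(q_out)   # [[4. 5. 6.]]: centered on V(s), relative preferences kept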

With that, the construction of this network has been fully analyzed. Now let's go back to 3 and look once more:

q_func = build_q_func(network, **network_kwargs)

With this single statement, it obtains a function that can build the network graph. q_func is still only a function: nothing has been passed to it yet, and no graph has been built.
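
To make that last point concrete, here is a hypothetical sketch of what finally calling q_func could look like; the placeholder shape, the action count of 6, and the scope name are my own illustration, not the actual call site inside learn():

import tensorflow as tf

# Hypothetical: the graph only gets built once q_func is actually called.
obs_ph = tf.placeholder(tf.float32, [None, 84, 84, 4], name="observation")
q_values = q_func(obs_ph, num_actions=6, scope="deepq")   # graph is built here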
