強化學習快餐教程(3) - 一條命令搞定atari遊戲

強化學習快餐教程(3) - 一條命令搞定atari遊戲

通過上節的例子,我們試驗出來,就算是像cartpole這樣讓一個杆子不倒這樣的小模型,都不是特別容易搞定的。

那麼像太空入侵者這麼複雜的問題,建模都建不出來,算法該怎麼寫?

別急,我們從強化學習的基礎來講起,學習馬爾可夫決策過程,瞭解貝爾曼方程、最優值函數、最優策略及其求解。然後學習動態規劃法、蒙特卡洛法、時間差分法、值函數近似法、策略梯度法。再然後我們借用深度學習的武器來武裝強化學習算法,我們會學習DQN算法族,講解2013版的基於Replay Memory的DQN算法,還有2015年增加了Target網絡的新DQN算法,還有Double DQN、優先級回放DQN和Dueling DQN,以及PG算法族的DPG,Actor-Critic,DDPG,以及A3C算法等等。

有的同學表示已經看暈了,除了堆了一堆名詞什麼也沒看懂。別怕,我們程序員解決這樣的問題統統是上代碼說話,我們馬上給大家一條命令就能解決太空入侵者以及其它更復雜的遊戲。
有的同學表示還沒看暈,嗯,這樣的話後面我們增加一些推導公式和解讀最新論文的環節 :)

baseline登場

下面,就是見證奇蹟的時刻。對於太空入侵者這麼複雜的問題,我們程序員用什麼辦法來解決它呢?答案是,調庫!
庫到哪裏找呢?洪七公說過:凡毒蛇出沒之處,七步內必有解救蛇毒之藥。 其他毒物,無不如此,這是天地間萬物生克的至理。OpenAI成天跟這些難題打交道,肯定有其解法。
沒錯,我們要用的庫,就是openAI的baselines庫。

調庫之前,我們先下載源碼:

git clone https://github.com/openai/baselines

然後安裝一下庫:

pip install -e .

見證奇蹟的時刻到了,我們用一條命令,就可以解決這些複雜的atari遊戲,以太空入侵者爲例:

python -m baselines.run --alg=ppo2 --env=SpaceInvadersNoFrameskip-v4

baselines.run有兩個參數:一個是算法,ppo2是OpenAI自家的最近策略優化算法Proximal Policy Optimization Algorithms;另一個是遊戲環境,SpaceInvaders是太空入侵者遊戲,後面的NoFrameskip-v4是控制參數,之前我們使用v0或v1中,爲了模擬人的控制,不讓控制太精準,所以都是有4幀左右的重複。而在訓練的時候就去掉這個負優化,使訓練的效率更高。

然後baselines就會和atari模擬器打交道去進行訓練了,輸入類似於下面這樣:

Stepping environment...
Done.
--------------------------------------
| eplenmean               | 1.12e+03 |
| eprewmean               | 836      |
| fps                     | 721      |
| loss/approxkl           | 0.0043   |
| loss/clipfrac           | 0.144    |
| loss/policy_entropy     | 0.786    |
| loss/policy_loss        | -0.0137  |
| loss/value_loss         | 0.148    |
| misc/explained_variance | 0.946    |
| misc/nupdates           | 7.30e+03 |
| misc/serial_timesteps   | 9.35e+05 |
| misc/time_elapsed       | 1.08e+04 |
| misc/total_timesteps    | 7.48e+06 |
--------------------------------------
Stepping environment...
Done.
--------------------------------------
| eplenmean               | 1.12e+03 |
| eprewmean               | 836      |
| fps                     | 764      |
| loss/approxkl           | 0.00372  |
| loss/clipfrac           | 0.134    |
| loss/policy_entropy     | 0.713    |
| loss/policy_loss        | -0.0144  |
| loss/value_loss         | 0.118    |
| misc/explained_variance | 0.949    |
| misc/nupdates           | 7.31e+03 |
| misc/serial_timesteps   | 9.35e+05 |
| misc/time_elapsed       | 1.08e+04 |
| misc/total_timesteps    | 7.48e+06 |
--------------------------------------
Stepping environment...
Done.
--------------------------------------
| eplenmean               | 1.12e+03 |
| eprewmean               | 831      |
| fps                     | 749      |
| loss/approxkl           | 0.0039   |
| loss/clipfrac           | 0.13     |
| loss/policy_entropy     | 0.706    |
| loss/policy_loss        | -0.0143  |
| loss/value_loss         | 0.108    |
| misc/explained_variance | 0.941    |
| misc/nupdates           | 7.31e+03 |
| misc/serial_timesteps   | 9.35e+05 |
| misc/time_elapsed       | 1.08e+04 |
| misc/total_timesteps    | 7.48e+06 |
--------------------------------------

運行個一千萬步左右,就可以看到分數的結果了。

如果我們不喜歡PPO算法,也可以換成其它的經典算法,比如DQN,把算法參數改成deepq就好。我們舉個用DQN算法的例子,遊戲我們也換成Breakout彈珠遊戲吧。

python -m baselines.run --alg=deepq --env=BreakoutNoFrameskip-v4 --num_timesteps=1e6

大家注意到,我們還增加了第三個參數,執行多少步迭代的數目,這裏我們選10萬步。論文中一般選1000萬步。

模型辛苦訓練出來了,後面遊戲中還要用呢,我們再增加一個參數保存起來:

python -m baselines.run --alg=ppo2 --env=SpaceInvadersNoFrameskip-v4 --num_timesteps=1e6 --save_path=./models/cartpole_1M_ppo2

使用深度Q學習解決cartpole問題

通過上面的樣例,我們對於強化學習有了信心。下面我們學習用代碼調用baselines的方法,我們首先以之前搞不定的cartpole那個立不住的杆子開始入手。

4步法訓練深度強化模型

通過baselines調用強化學習庫非常簡單,只要4步就可以了:
第一步,創建cartpole遊戲環境,這個我們已經非常熟悉了,使用gym.make來創建一個env

env = gym.make("CartPole-v0")

第二步,確定訓練結束的條件。這也跟庫和算法無關,是領域知識。對於cartpole來說,能堅持199步不倒,也就是杆子的傾角不超過12度,且小車沒有超出範圍,就算是成功:

def callback(lcl, _glb):
    # stop training if reward exceeds 199
    is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
    return is_solved

第三步,選擇一種強化學習算法並調用其訓練方法,我們這裏選用最經典的DQN算法:

act = deepq.learn(
    env,
    network='mlp',
    lr=1e-3,
    total_timesteps=100000,
    buffer_size=50000,
    exploration_fraction=0.1,
    exploration_final_eps=0.02,
    print_freq=10,
    callback=callback
)

參數中,env是要用來訓練的遊戲環境,callback是我們剛纔寫的成功條件,total_timesteps是要練訓多少步,我們選10萬步就足夠了。
第四步,將訓練好的模型保存起來:

act.save("cartpole_model.pkl")

經過這4步,我們就可以在一不懂算法原理,二不懂超參數是什麼的情況下,成功解決cartpole問題,源碼整合下如下:

import gym

from baselines import deepq


def callback(lcl, _glb):
    # stop training if reward exceeds 199
    is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
    return is_solved


env = gym.make("CartPole-v0")
act = deepq.learn(
    env,
    network='mlp',
    lr=1e-3,
    total_timesteps=100000,
    buffer_size=50000,
    exploration_fraction=0.1,
    exploration_final_eps=0.02,
    print_freq=10,
    callback=callback
)
print("Saving model to cartpole_model.pkl")
act.save("cartpole_model.pkl")

運行模型

模型訓練好了,我們就拿來跑一下驗證效果吧,我們把策略生成交給剛纔的deepq.learn,其餘的大家都已經很熟悉了,就是標準的運行遊戲的過程:

import gym

from baselines import deepq

env = gym.make("CartPole-v0")
act = deepq.learn(env, network='mlp', total_timesteps=0,load_path="cartpole_model.pkl")

obs, done = env.reset(), False
episode_rew = 0
while not done:
    env.render()
    obs, rew, done, _ = env.step(act(obs[None])[0])
    episode_rew += rew
    print("Episode reward", episode_rew)

從運行圖上,我們就可以看到pole再也不倒了。從打印出來的reward結果上也證明了這一點:

Loaded model from cartpole_model.pkl
Episode reward 1.0
Episode reward 2.0
Episode reward 3.0
Episode reward 4.0
...
Episode reward 198.0
Episode reward 199.0
Episode reward 200.0
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章