強化學習快餐教程(2) - atari遊戲
不知道看了上節的內容,大家有沒有找到讓杆不倒的好算法。
現在我們晉階一下,向世界上第一種大規模的遊戲機atari前進。
太空入侵者
可以通過
pip install atari_py
來安裝atari遊戲。
下面我們以SpaceInvaders-v0爲例看下Atari遊戲的環境的特點。
圖形版
在太空入侵者中,支持的輸入有6種,一個是什麼也不做,一個是開火,另4個是控制方向:
- 0: NOOP
- 1: FIRE
- 2: UP
- 3: RIGHT
- 4: LEFT
- 5: DOWN
我們從環境中獲取的信息是什麼呢?很不幸,是一個(210, 160, 3)的圖片,顯示出來是這樣的:
我們寫代碼把這個環境搭起來。策略嘛,我就原地不動一直開火。
import gym
from skimage import io
env = gym.make('SpaceInvaders-v0')
status = env.reset()
for step in range(1000):
env.render()
thisstep = 1
status, reward, done, info = env.step(thisstep)
jpgname = './pic-%d.jpg' % step
io.imsave(jpgname,status)
print(reward)
if done:
print('dead in %d steps' % step)
break
env.close()
大家可以通過保存下來的pic-x.jpg來直觀觀察遊戲的情況,比如我在第138步時,打中了一個5分的入侵者。
太空入侵者這個遊戲的策略比起cartpole,需要分析圖像,這個是不足。但是它也是有好處的,就是reward參數現在會把分數返回給我們。總算是分數處理上不需要搞圖像分析了。
下面我來寫個算法吧,以1/4的概率左右移動,另外3/4開火:
import gym
env = gym.make('SpaceInvaders-v0')
status = env.reset()
def policy(step):
state_pool = [3,4,3,3,4,4,3,3,3,4,4,4,3,3,3,3,4,4,4,4]
if step % 4 == 0:
pos = step / 4
result = state_pool[int(pos % (len(state_pool)))]
return result
else:
return 1
for step in range(10000):
env.render()
thisstep = policy(step)
print(thisstep)
status, reward, done, info = env.step(thisstep)
#print(reward)
if done:
print('dead in %d steps' % step)
break
env.close()
內存版
如果圖像分析做起來不方便的話,gym還爲我們提供了RAM版的。就是將遊戲機中的128個字節的內存信息提供給我們。
下面是env.reset的128個字節的例子:
[ 0 7 0 68 241 162 34 183 68 13 124 255 255 50 255 255 0 36
63 63 63 63 63 63 82 0 23 43 35 117 180 0 36 63 63 63
63 63 63 110 0 23 1 60 126 126 126 126 255 255 255 195 60 126
126 126 126 255 255 255 195 60 126 126 126 126 255 255 255 195 0 0
48 3 129 0 0 0 0 0 0 246 246 63 63 246 246 63 63 0
21 24 0 52 82 196 246 20 7 0 226 0 0 0 0 0 21 63
0 128 171 0 255 0 189 0 0 0 0 0 99 255 0 0 235 254
192 242]
輸入的部分跟圖像版是一樣的,我們代碼修改如下:
import gym
env = gym.make('SpaceInvaders-ram-v0')
status = env.reset()
print(status)
def policy(step):
state_pool = [3,4,3,3,4,4,3,3,3,4,4,4,3,3,3,3,4,4,4,4]
if step % 4 == 0:
pos = step / 4
result = state_pool[int(pos % (len(state_pool)))]
return result
else:
return 1
for step in range(10000):
env.render()
thisstep = policy(step)
#print(thisstep)
status, reward, done, info = env.step(thisstep)
#print(reward)
if done:
print('dead in %d steps' % step)
break
env.close()
breakout
下面我們再來一個彈球遊戲。
彈球遊戲的輸入是4個值。
圖片版的:
import gym
from skimage import io
env = gym.make('Breakout-v0')
status = env.reset()
#print(status)
print(env.action_space)
def policy(step):
if step % 2 == 0:
return 2
else:
return 3
for step in range(100):
env.render()
thisstep = policy(step)
#print(thisstep)
status, reward, done, info = env.step(thisstep)
jpgname = './pic-%d.jpg' % step
io.imsave(jpgname,status)
#print(reward)
if done:
print('dead in %d steps' % step)
break
env.close()
RAM版的例子:
import gym
env = gym.make('Breakout-ram-v0')
status = env.reset()
print(status)
print(env.action_space)
def policy(step):
return step % 4
for step in range(100):
env.render()
thisstep = policy(step)
#print(thisstep)
status, reward, done, info = env.step(thisstep)
#print(reward)
if done:
print('dead in %d steps' % step)
break
env.close()
RAM的初始值:
[ 63 63 63 63 63 63 255 255 255 255 255 255 255 255 255 255 255 255
255 255 255 255 255 255 255 255 255 255 255 255 192 192 192 192 192 192
255 255 255 255 255 255 255 255 255 255 255 255 255 240 0 0 255 0
0 240 0 5 0 0 6 0 70 182 134 198 22 38 54 70 88 6
146 0 8 0 0 0 0 0 0 241 0 242 0 242 25 241 5 242
0 0 255 0 228 0 0 0 0 0 0 0 0 0 0 0 0 0
8 0 255 255 255 255 255 255 255 0 0 5 0 0 186 214 117 246
219 242]