這裏通過zoe走迷宮例子再次學習Q-learning。與強化學習系列(1)中思想一致,其區別主要是通過兩個類,迷宮環境Maze和zoe大腦QLearningTable來規範化程序,同時在運行函數步驟來清晰化Q學習的過程。
Part 1. 迷宮環境class Maze
Part 2. zoe大腦class QLearningTable
Part 3. 運行函數
Part 4. 演示效果
Part 1. 迷宮環境class Maze
import numpy as np
import time
import sys
# 窗口界面庫
if sys.version_info.major == 2:
import Tkinter as tk
else:
import tkinter as tk
UNIT = 40    # pixel size of one grid cell
MAZE_H = 4   # maze height, in cells
MAZE_W = 4   # maze width, in cells
class Maze(tk.Tk, object):
    """Grid-world maze rendered with tkinter.

    A MAZE_W x MAZE_H grid containing two "hell" cells (black rectangles,
    reward -1), one treasure cell (yellow oval, reward +1) and the agent
    (red rectangle) which starts in the top-left cell.
    """

    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze')
        # tk geometry is 'WIDTHxHEIGHT'. The original passed MAZE_H for
        # both dimensions, which only worked because MAZE_H == MAZE_W;
        # width must come from MAZE_W.
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))
        self._build_maze()

    def _build_maze(self):
        """Draw the grid lines, the two hell cells, the treasure and the agent."""
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)
        # Vertical grid lines, then horizontal ones.
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # Center of the top-left cell; all items are 30x30 boxes around a center.
        origin = np.array([20, 20])
        # Hell cell 1 (two cells right, one cell down).
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        # Hell cell 2 (one cell right, two cells down).
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')
        # Treasure (two cells right, two cells down).
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')
        # The agent, starting at the origin cell.
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        self.canvas.pack()

    def reset(self):
        """Move the agent back to the start cell.

        Returns the agent's canvas coordinates, used as the initial state.
        """
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        return self.canvas.coords(self.rect)

    def step(self, action):
        """Apply an integer action and return (next_state, reward, done).

        Action mapping: 0=up, 1=down, 2=right, 3=left; moves that would
        leave the grid are ignored. NOTE(review): this order does not match
        self.action_space (['u', 'd', 'l', 'r']); harmless because
        action_space is only ever used for its length, but worth aligning.
        Reward: +1 on the treasure, -1 on either hell cell (both terminal,
        next state becomes the string 'terminal'), 0 otherwise.
        """
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:      # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:    # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:    # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:    # left
            if s[0] > UNIT:
                base_action[0] -= UNIT
        self.canvas.move(self.rect, base_action[0], base_action[1])
        s_ = self.canvas.coords(self.rect)
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False
        return s_, reward, done

    def render(self):
        """Refresh the GUI, slowed down so a human can follow the agent."""
        time.sleep(0.1)
        self.update()
# --------------------------------------------------------
# 主函數:程序入口
"""
if __name__ == "__main__":
env = Maze()
env.after(100, update)
env.mainloop()
"""
Part 2. zoe大腦class QLearningTable
import numpy as np
import pandas as pd
class QLearningTable:
    """Tabular Q-learning agent backed by a pandas DataFrame.

    Rows are states (added lazily as strings), columns are action indices.
    """

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions        # list of action indices, e.g. [0, 1, 2, 3]
        self.lr = learning_rate       # alpha: step size of the Q update
        self.gamma = reward_decay     # discount factor for future reward
        self.epsilon = e_greedy       # probability of acting greedily
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def choose_action(self, observation):
        """Pick an action for state `observation` with epsilon-greedy policy."""
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            # Greedy: break ties randomly among equally valued best actions.
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(
                state_action[state_action == np.max(state_action)].index)
        else:
            # Explore: uniform random action.
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        """Update Q(s, a) from one (s, a, r, s_) transition.

        `s_` equal to the string 'terminal' means the episode ended, so the
        target is the bare reward with no bootstrapped future value.
        """
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

    def check_state_exist(self, state):
        """Add a zero-valued row for `state` if it is not in the table yet.

        Uses .loc assignment: DataFrame.append was deprecated in pandas 1.4
        and removed in 2.0, so the original append-based code breaks on
        current pandas.
        """
        if state not in self.q_table.index:
            self.q_table.loc[state] = [0.0] * len(self.actions)
Part 3. 運行函數
from maze_env import Maze
from RL_brain import QLearningTable
def update(episodes=100):
    """Run the Q-learning training loop, then close the window.

    Relies on the module-level globals `env` (a Maze) and `RL`
    (a QLearningTable) created in the __main__ block; observations are
    stringified because the Q-table is indexed by state strings.

    Args:
        episodes: number of training episodes to run (default 100,
            matching the original hard-coded value, so the existing
            `env.after(100, update)` call is unaffected).
    """
    for episode in range(episodes):
        observation = env.reset()
        step_counter = 0
        while True:
            step_counter += 1
            env.render()
            action = RL.choose_action(str(observation))
            observation_, reward, done = env.step(action)
            RL.learn(str(observation), action, reward, str(observation_))
            observation = observation_
            if done:
                # Episode ends on reaching the treasure or a hell cell.
                break
        print('Episode %s: total_steps = %s' % (episode, step_counter))
    print('game over')
    env.destroy()
# --------------------------------------------------
# 程序入口,main函數
if __name__ == "__main__":
    # Build the environment and the agent (one action per env action slot),
    # schedule the training loop 100 ms after the window appears, then hand
    # control to the tk event loop.
    env = Maze()
    RL = QLearningTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()
Part 4. 演示效果