Sarsa(λ) is still Q-table-based reinforcement learning; what changes relative to parts (1) and (2) of this series is how state-action values are updated. Sarsa(λ) maintains an eligibility_trace matrix of the same size as the Q-table and uses it to update the Q-table. Here is zoe's own reading of the eligibility_trace: it raises the weight of the action just taken, widening the gap between that action's value and the others', so the current step more strongly guides the decision at the next step (or, put another way, the previous step more strongly guides the current one). Morvan notes in his run results that this can make zoe so wary of stepping into hell that it keeps wandering inside a small safe region.
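A minimal numpy sketch (toy numbers of my own, not from the original post) of the backward-view update that the SarsaLambdaTable below implements: each step's TD error updates every state-action pair in proportion to its trace, and the whole trace then decays by gamma * lambda.

import numpy as np

gamma, lam, lr = 0.9, 0.9, 0.1
q = np.zeros((3, 2))       # toy Q-table: 3 states x 2 actions
trace = np.zeros_like(q)   # eligibility trace, same shape as the Q-table

trace[0, 1] = 1.0          # the pair just visited gets a full ("replacing") trace
error = 1.0                # pretend TD error for this step
q += lr * error * trace    # one whole-table update, weighted by the trace
trace *= gamma * lam       # older visits keep an exponentially smaller share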
Part 1. Maze environment: class Maze
Part 2. zoe's brain: class SarsaLambdaTable
Part 3. Run function
Part 4. Demo results
Part 1. Maze environment: class Maze
"""
Maze環境與系列(1)相同,沒有發生改變。
"""
import numpy as np
import time
import sys

# GUI toolkit (Tkinter under Python 2, tkinter under Python 3)
if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk

UNIT = 40    # pixels per grid cell
MAZE_H = 4   # maze height in cells
MAZE_W = 4   # maze width in cells
class Maze(tk.Tk, object):  # subclass of tk.Tk
    def __init__(self):
        super(Maze, self).__init__()  # initialize the tk.Tk parent
        # note: step() maps the indices as 0 = up, 1 = down, 2 = right, 3 = left
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze')
        # window geometry is 'widthxheight'; the original used MAZE_H for both,
        # which only worked because the maze is square
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))
        self._build_maze()
    # build the maze scenery
    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',  # tk drawing widget
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)
        # grid lines
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # origin: center of the top-left cell
        origin = np.array([20, 20])
        # hell cells (black squares, reward -1)
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')
        # oval: the treasure (reward +1)
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')
        # rect: zoe, the agent
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        self.canvas.pack()  # lay out the canvas
        # return self.canvas
    # send zoe back to the start
    def reset(self):
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red'
        )
        return self.canvas.coords(self.rect)
    # produce the next state and apply the reward scheme
    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:    # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:  # left
            if s[0] > UNIT:
                base_action[0] -= UNIT
        self.canvas.move(self.rect, base_action[0], base_action[1])
        s_ = self.canvas.coords(self.rect)
        # reward judgment
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False
        return s_, reward, done
    def render(self):
        time.sleep(0.1)
        self.update()
# --------------------------------------------------------
# main entry point (kept commented out; the update() loop is defined in the run script of Part 3)
"""
if __name__ == "__main__":
    env = Maze()
    env.after(100, update)
    env.mainloop()
"""
Part 2. zoe's brain: class SarsaLambdaTable
import numpy as np
import pandas as pd

class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # note: DataFrame.append was removed in pandas 2.0;
            # on newer pandas, build the row and use pd.concat instead
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    # choose action a for the current observed state s (epsilon-greedy)
    def choose_action(self, observation):
        self.check_state_exist(observation)
        if np.random.rand() < self.epsilon:
            # exploit: pick among the highest-valued actions, breaking ties at random
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # explore: pick a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass
"""
# off-policy
class QLearningTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
# 得到(s, a, r, s_)序列後進行狀態動作值更新
def learn(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
else:
q_target = r
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
#return self.q_table
"""
"""
# on-policy
class SarsaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_]
else:
q_target = r
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
"""
class SarsaLambdaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
        self.lambda_ = trace_decay
        # eligibility trace: same shape as the Q-table
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            to_be_append = pd.Series(
                [0] * len(self.actions),
                index=self.q_table.columns,
                name=state,
            )
            # a new state must be added to both tables so they stay aligned
            self.q_table = self.q_table.append(to_be_append)
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict

        # increase the trace for the visited state-action pair
        # Method 1: accumulating trace
        # self.eligibility_trace.loc[s, a] += 1
        # Method 2: replacing trace (cap the visited pair at 1)
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1

        # update the whole Q-table at once, weighted by each pair's trace
        self.q_table += self.lr * error * self.eligibility_trace

        # decay the eligibility trace after the update
        self.eligibility_trace *= self.gamma * self.lambda_
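A toy illustration (my own numbers, not from the post) of the difference between the two trace styles commented above, on a one-row trace table: Method 1 lets repeated visits pile up past 1, while Method 2 caps the visited pair at 1.

import pandas as pd

trace = pd.DataFrame([[0.0, 0.0]], index=['s0'], columns=[0, 1])

# Method 1 (accumulating trace): two visits before any decay stack up
trace.loc['s0', 1] += 1
trace.loc['s0', 1] += 1      # trace is now 2.0

# Method 2 (replacing trace): clear the row, then mark the visited pair
trace.loc['s0', :] *= 0
trace.loc['s0', 1] = 1       # trace is 1.0 no matter how often the pair is revisited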
Part 3. Run function
from maze_env import Maze
# from RL_brain import SarsaTable
from RL_brain import SarsaLambdaTable

def update():
    for episode in range(5):
        observation = env.reset()
        action = RL.choose_action(str(observation))
        # print(str(observation))
        # clear the eligibility trace at the start of each episode
        RL.eligibility_trace *= 0
        step_counter = 0
        while True:
            step_counter += 1
            env.render()
            observation_, reward, done = env.step(action)
            action_ = RL.choose_action(str(observation_))
            # q_table = RL.learn(str(observation), action, reward, str(observation_), action_)
            RL.learn(str(observation), action, reward, str(observation_), action_)
            observation = observation_
            action = action_
            if done:
                break
        print('Episode %s: total_steps = %s' % (episode, step_counter))
    print('game over')
    env.destroy()
    return (RL.q_table, RL.eligibility_trace)
if __name__ == "__main__":
    env = Maze()
    # RL = SarsaTable(actions=list(range(env.n_actions)))
    RL = SarsaLambdaTable(actions=list(range(env.n_actions)))
    print('\r\nQ-table at start:\n')
    print(RL.q_table)
    print('\r\neligibility_trace at start:\n')
    print(RL.eligibility_trace)
    env.after(100, update)
    env.mainloop()
    print('\r\nQ-table at end:\n')
    print(RL.q_table)
    print('\r\neligibility_trace at end:\n')
    print(RL.eligibility_trace)
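One way to read the trace_decay parameter: with trace_decay=0 the trace dies immediately after every step, so SarsaLambdaTable reduces to the one-step SarsaTable above; as trace_decay approaches 1, each reward is spread further back along the whole visited path. This backward spreading is also what lets the -1 from hell propagate to earlier steps, producing the overly cautious wandering that Morvan describes.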