作爲一個非專業初學愛好者,在看了一些強化學習教程之後決定從Q-table入門強化學習。我參考的資料很多,個人感覺下邊這個鏈接https://mp.weixin.qq.com/s/34E1tEQMZuaxvZA66_HRwA講的不錯。之前接觸過Q-table的簡單理論,但是一直沒有實踐一下,一寫代碼才發現很多問題其實自己沒有考慮清楚。現在附上一份剛寫不久的Q-table代碼。詳細原理不再多說,代碼中不懂的具體可以看註釋。
import numpy as np
from time import sleep
# Q_table的更新順序其實是倒着更新,離終點越近的會先更新,然後由更新公式一點一點將接近初始點的Q值更新
class Q_table():
def __init__(self):
self.table = np.zeros([4, 7, 10]) # [X,_,_] X=0:上 X=1:右 X=2:下 X=3:左
self.table[0, 0, :] = -99 # 超出邊界的動作獎勵設置很小
self.table[1, :, 9] = -99
self.table[2, 6, :] = -99
self.table[3, :, 0] = -99
self.offset = 0 # 測試獎勵值在不同範圍的情況時使用的,不是必要使用
self.Reward = np.array([ # 獎勵值的設定很重要!!!!
[0, 0, 0, 0, -1, 0, 0, 0, 0, 0],
[0, -1, -1, 0, 0, 0, 0, 0, -1, 0],
[0, 0, 0, -1, 0, 0, -1, 0, 0, 0],
[-1, 0, 0, 0, 0, -1, 0, 0, 0, 0],
[0, -1, 0, -1, -1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, -1, 1, -1, 0, 0],
[0, 0, 0, -1, 0, 0, 0, 0, 0, 0],
])-self.offset
print(self.Reward)
self.cur_y, self.cur_x = 0, 0 # 當前座標
self.lr = 0.8 # 學習率
self.discount = 0.8 # 折扣率
def update_Q_table(self, pos_y, pos_x, action): # 更新Q表的方法,輸入當前位置座標、要採取的動作。
self.cur_y, self.cur_x = pos_y, pos_x
next_y, next_x = self.cur_y, self.cur_x # 初始化next_y和next_x
update_flag = False
# 選取對應的action
if action == 0:
if self.cur_y > 0: # 保證next_y不會超出數組範圍,否則無效
next_y, next_x = self.cur_y - 1, self.cur_x
update_flag = True
elif action == 1:
if self.cur_x < 9:# 保證next_x不會超出數組範圍,否則無效
next_y, next_x = self.cur_y, self.cur_x + 1
update_flag = True
elif action == 2:
if self.cur_y < 6:
next_y, next_x = self.cur_y + 1, self.cur_x
update_flag = True
elif action == 3:
if self.cur_x > 0:
next_y, next_x = self.cur_y, self.cur_x - 1
update_flag = True
# 如果採取的action有效的話
if update_flag == True:
# 假設原來是在(x0,y0)(本程序(x0,y0)就是(cur_x,cur_y)),
# 執行action後是在(x1,y1)(本程序(x1,y1)就是(next_x,next_y),
# 就把(x0,y0)執行action後即在(x1,y1)處的4個可能的action的Q值保存
next_pos_all_Q_actions_list = []
# if next_y >= 0: # 保證(x1,y1)的4個可能的action值不會越界,不用寫了。
# next_pos_all_Q_actions_list.append(self.table[0, next_y, next_x])
# else:
# next_pos_all_Q_actions_list.append(-99)
# if next_x <= 9:
# next_pos_all_Q_actions_list.append(self.table[1, next_y, next_x])
# else:
# next_pos_all_Q_actions_list.append(-99)
# if next_y <= 6:
# next_pos_all_Q_actions_list.append(self.table[2, next_y, next_x])
# else:
# next_pos_all_Q_actions_list.append(-99)
# if next_y >= 0:
# next_pos_all_Q_actions_list.append(self.table[3, next_y, next_x])
# else:
# next_pos_all_Q_actions_list.append(-99)
for i in range(4):
next_pos_all_Q_actions_list.append(self.table[i, next_y, next_x]) # 在(x1,y1)處的4個可採取動作的Q值
next_pos_all_Q_actions = np.array(next_pos_all_Q_actions_list)
max_next_pos_Q_val = np.max(next_pos_all_Q_actions) # 找那4個Q值的最大值
next_action = np.argmax(next_pos_all_Q_actions) # 找那4個Q值的最大值的位置,就得到了相應的action
delta_Q = self.Reward[next_y, next_x] + self.discount * max_next_pos_Q_val \
- self.table[action, self.cur_y, self.cur_x]
self.table[action, self.cur_y, self.cur_x] += self.lr * delta_Q # 更新Q
def show_actions(self): # 演示執行動作
pos_y, pos_x = 0, 0
whole_map = np.zeros([7, 10])
whole_map[pos_y, pos_x] = 1
for i in range(20):
next_action = np.argmax([self.table[0, pos_y, pos_x],
self.table[1, pos_y, pos_x],
self.table[2, pos_y, pos_x],
self.table[3, pos_y, pos_x]])
if next_action == 0:
pos_y -= 1
elif next_action == 1:
pos_x += 1
elif next_action == 2:
pos_y += 1
elif next_action == 3:
pos_x -= 1
# whole_map = np.zeros([7, 10])
whole_map[pos_y, pos_x] = 1
print('='*40)
print(whole_map)
sleep(0.5)
if self.Reward[pos_y,pos_x]==1-self.offset:
break
qtable = Q_table()
n = 0
while True:
for x in range(10):
for y in range(7):
for a in range(4):
qtable.update_Q_table(y, x, a)
n += 1
if n == 200: # 訓練200次
break
print(qtable.table)
qtable.show_actions()
歡迎評論。