The idea behind the Q-learning algorithm
The first Spring Festival of the 2020s is almost here, so let me wish everyone an early Happy New Year. Now that I no longer get a winter vacation, I don't look forward to the holiday quite as much as I used to, but I still have something to hope for: that little bit of year-end bonus. Over a year of work there has been striving, wavering, and slacking off, and all of it shows up at year's end. Through effort and laziness alike, I cycled through many states this year, and after this whole sequence of state transitions I naturally hope the year closes with a good result. If we discretize the year into time steps, where at each step I take an action and transition to a new state, every transition yields a corresponding reward. What I want is for the total value of all these transitions to be as high as possible. That is exactly the idea behind reinforcement learning's Q-learning.
Let there be a set of states $S$ and a set of actions $A$; taking action $a$ in state $s$ moves the agent to a new state $s'$ and yields a reward (or penalty) $r$. Suppose at some point in the year I am in the slacking state and take the action of working overtime; the boss notices and gives me a raise. Or, while in the hard-working state, I suddenly feel down, start slacking off, and the boss notices and cuts my pay. This is what the state-action transitions in Q-learning mean. Naturally, I want to maximize my total gain for the year. The year's total return can be written as the discounted sum of per-step rewards, $G = \sum_{t=0}^{T} \gamma^{t} r_{t}$, where $\gamma \in [0, 1]$ is the discount factor.
When we are in a given state, Bellman's principle of optimality applies: an optimal path from $(i_0, j_0)$ to $(i_f, j_f)$ that passes through $(i, j)$ is the concatenation of an optimal path from $(i_0, j_0)$ to $(i, j)$ and an optimal path from $(i, j)$ to $(i_f, j_f)$. The state transition can therefore be expressed recursively: $Q(s, a) = R(s, a) + \gamma \max_{a'} Q(s', a')$.
Below we use an example to show how this $Q$ expression is computed in practice.
Consider the rooms shown in the figure above (rooms 0 through 5), where room 5 is the final destination. The same layout can be represented as a directed graph, as shown below.
The reward matrix R for moving between the rooms is as follows (rows are the current room, columns the target room):
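Reconstructed from the R matrix defined in the code below; −1 marks a move with no connecting door, and 100 is the reward for entering room 5:

          to 0   to 1   to 2   to 3   to 4   to 5
from 0     -1     -1     -1     -1      0     -1
from 1     -1     -1     -1      0     -1    100
from 2     -1     -1     -1      0     -1     -1
from 3     -1      0      0     -1      0     -1
from 4      0     -1     -1      0     -1    100
from 5     -1      0     -1     -1      0    100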
The value function is then written as Eq. (1.1):

$$Q(s, a) = R(s, a) + \gamma \cdot \max_{a'} Q(s', a') \tag{1.1}$$

where $s'$ is the next state and the maximum runs over all actions available there.
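A quick hand calculation (not in the original post) using the R matrix above, with $\gamma = 0.8$ and $Q$ initialized to all zeros: moving from room 1 to room 5 gives $Q(1,5) = R(1,5) + 0.8 \times \max_a Q(5,a) = 100 + 0.8 \times 0 = 100$. On a later episode, moving from room 3 to room 1 then gives $Q(3,1) = R(3,1) + 0.8 \times \max\{Q(1,3), Q(1,5)\} = 0 + 0.8 \times 100 = 80$.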
Python implementation of the Q-learning algorithm
# -*- coding: utf-8 -*-
# @Time : 2020/1/10 9:35
# @Author : HelloWorld!
# @FileName: q_room.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.6
import numpy as np
import random

# Initialize the Q matrix to zeros
Q = np.matrix(np.zeros((6, 6)))
# Reward matrix R: rows are states, columns are actions; -1 marks an impossible move
R = np.matrix([[-1, -1, -1, -1, 0, -1],
               [-1, -1, -1, 0, -1, 100],
               [-1, -1, -1, 0, -1, -1],
               [-1, 0, 0, -1, 0, -1],
               [0, -1, -1, 0, -1, 100],
               [-1, 0, -1, -1, 0, 100]])
# Discount factor
gamma = 0.8
# Training
for i in range(2000):
    # For each episode, start from a random state
    state = random.randint(0, 5)
    while True:
        # Collect all actions available in the current state
        r_pos_action = []
        for action in range(6):
            if R[state, action] >= 0:
                r_pos_action.append(action)
        # Pick one of them at random
        next_state = r_pos_action[random.randint(0, len(r_pos_action) - 1)]
        # Update: Q(s, a) = R(s, a) + gamma * max Q(s', a')
        Q[state, next_state] = R[state, next_state] + gamma * (Q[next_state]).max()
        state = next_state
        if i % 100 == 0:
            print('---------------------', i)
            print(Q)
        # Room 5 is the terminal (goal) state
        if state == 5:
            break
print(Q)
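Once training has converged, the greedy policy can be read straight off the learned Q matrix. A minimal sketch (not part of the original post; it assumes the Q trained above, otherwise the loop may not terminate):

# From each room, always move to the neighbour with the largest Q-value
state = 2  # e.g. start from room 2
path = [state]
while state != 5:
    state = int(np.argmax(Q[state]))
    path.append(state)
print('greedy path to room 5:', path)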
Conclusion
Q-learning is the most basic reinforcement learning algorithm. When the action set is small but the state space is large, a deep neural network can be used to approximate the value function. In other words, once the states, actions, and value function of a problem are properly mapped out, reinforcement learning can be applied to solve it.
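As a rough sketch of that idea (not from the original post; it assumes PyTorch and a one-hot state encoding), the Q-table can be replaced by a small network that maps a state to one Q-value per action, trained toward the same target r + γ · max Q(s′, ·):

import torch
import torch.nn as nn

class QNetwork(nn.Module):
    # Replaces the Q-table: input is a state vector, output is one Q-value per action
    def __init__(self, n_states=25, n_actions=4, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state):
        return self.net(state)

q_net = QNetwork()
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
# One TD update for a hypothetical transition (s=0, a=1, r=0, s'=1)
s = torch.eye(25)[0:1]       # one-hot encoding of state 0
s_next = torch.eye(25)[1:2]  # one-hot encoding of state 1
reward, gamma, action = 0.0, 0.8, 1
with torch.no_grad():
    target = reward + gamma * q_net(s_next).max()
loss = (q_net(s)[0, action] - target) ** 2
optimizer.zero_grad()
loss.backward()
optimizer.step()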
Appendix: an example of randomly exploring path choices from a starting point to a destination. Reference: https://blog.csdn.net/shankezh/article/details/102864085
You will need three PNG images, named boom.png, diamond.png, and player.png, placed in the same directory as Q_learing.py and Q_test.py.
# -*- coding: utf-8 -*-
# @Time : 2020/1/8 14:34
# @Author : HelloWorld!
# @FileName: Q_learing.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.6
import tkinter as tk
from PIL import ImageTk
from PIL import Image
import time


class Env:
    def __init__(self):
        self.grid_size = 100
        self.win = tk.Tk()
        self.pic_player, self.pic_diamond, self.pic_boom1, self.pic_boom2, self.pic_boom3, self.pic_boom4 = self.__load_img()
        self.__init_win()
        self.canvas = self.__init_rc()
        self.texts = self.__produce_text()
        self.canvas.pack()
        # self._init_test_case()
        # self.win.mainloop()

    def __init_win(self):
        self.win.title('Grid World')
        # self.win.geometry("500x300")

    def __init_rc(self):
        canvas = tk.Canvas(self.win, width=500, height=720, bg='white')
        for h in range(5):
            for v in range(5):
                canvas.create_rectangle(self.grid_size * v, self.grid_size * h, self.grid_size * (v + 1),
                                        self.grid_size * (h + 1))
        trans_pixel = int(self.grid_size / 2)
        self.player = canvas.create_image(trans_pixel + self.grid_size * 0, trans_pixel + self.grid_size * 0,
                                          image=self.pic_player)
        self.diamond = canvas.create_image(trans_pixel + self.grid_size * 4, trans_pixel + self.grid_size * 4,
                                           image=self.pic_diamond)
        self.boom1 = canvas.create_image(trans_pixel + self.grid_size * 1, trans_pixel + self.grid_size * 1,
                                         image=self.pic_boom1)
        self.boom2 = canvas.create_image(trans_pixel + self.grid_size * 3, trans_pixel + self.grid_size * 1,
                                         image=self.pic_boom2)
        self.boom3 = canvas.create_image(trans_pixel + self.grid_size * 1, trans_pixel + self.grid_size * 3,
                                         image=self.pic_boom3)
        self.boom4 = canvas.create_image(trans_pixel + self.grid_size * 3, trans_pixel + self.grid_size * 3,
                                         image=self.pic_boom4)
        return canvas

    def __load_img(self):
        pic_resize = int(self.grid_size / 2)
        player = ImageTk.PhotoImage(Image.open("player.png").resize((pic_resize, pic_resize)))
        diamond = ImageTk.PhotoImage(Image.open("diamond.png").resize((pic_resize, pic_resize)))
        boom1 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))
        boom2 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))
        boom3 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))
        boom4 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))
        return player, diamond, boom1, boom2, boom3, boom4

    def __produce_text(self):
        # Create four Q-value labels (up/down/left/right) in every grid cell
        texts = []
        x = self.grid_size / 2
        y = self.grid_size / 6
        for h in range(5):
            for v in range(5):
                up = self.canvas.create_text(x + h * self.grid_size, y + v * self.grid_size, text=0)
                down = self.canvas.create_text(x + h * self.grid_size, self.grid_size - y + v * self.grid_size, text=0)
                left = self.canvas.create_text(y + h * self.grid_size, x + v * self.grid_size, text=0)
                right = self.canvas.create_text(self.grid_size - y + h * self.grid_size, x + v * self.grid_size, text=0)
                texts.append({"up": up, "down": down, "left": left, "right": right})
        return texts

    def _win_d_update(self):
        self.win.update()
        time.sleep(0.1)


class GridWorld(Env):
    def __init__(self):
        super().__init__()
        self._win_d_update()

    def player_move(self, x, y):
        # x moves horizontally (positive = right), y vertically (positive = down)
        self.canvas.move(self.player, x * self.grid_size, y * self.grid_size)
        self._win_d_update()

    def reset(self):
        # Reset the player to the start cell
        x, y = self.canvas.coords(self.player)
        self.canvas.move(self.player, -x + self.grid_size / 2, -y + self.grid_size / 2)
        self._win_d_update()
        return self.get_state(self.player)

    def get_state(self, who):
        x, y = self.canvas.coords(who)
        state = [int(x / self.grid_size), int(y / self.grid_size)]
        return state

    def update_val(self, num, arrow, val):
        pos = num[0] * 5 + num[1]
        x, y = self.canvas.coords(self.texts[pos][arrow])
        self.canvas.delete(self.texts[pos][arrow])
        self.texts[pos][arrow] = self.canvas.create_text(x, y, text=val)
        # self._win_d_update()

    def exec_calc(self, action):
        # Execute one decision step
        feedback = 'alive'  # alive = moved, stop = hit a wall, dead = stepped on a bomb
        next_state = []
        next_h, next_v, reward = 0.0, 0.0, 0.0
        h, v = self.get_state(self.player)
        if action == 0:  # up
            next_h = h
            next_v = v - 1
            # self.player_move(0, -1)
        elif action == 1:  # down
            next_h = h
            next_v = v + 1
            # self.player_move(0, 1)
        elif action == 2:  # left
            next_h = h - 1
            next_v = v
            # self.player_move(-1, 0)
        elif action == 3:  # right
            next_h = h + 1
            next_v = v
            # self.player_move(1, 0)
        else:
            print('programmer bug ...')
        next_state = [next_h, next_v]
        boom1, boom2, boom3, boom4 = self.get_state(self.boom1), self.get_state(self.boom2), self.get_state(
            self.boom3), self.get_state(self.boom4)
        diamond = self.get_state(self.diamond)
        if next_h < 0 or next_v < 0 or next_h > 4 or next_v > 4:  # out of bounds
            reward = -1
            feedback = 'stop'
        elif next_state == boom1 or next_state == boom2 or next_state == boom3 or next_state == boom4:  # bomb cell
            reward = -100
            feedback = 'dead'
        elif next_state == diamond:  # reached the goal item
            reward = 500
        else:
            reward = 0
        return feedback, next_state, reward

    def update_view(self, state, action, next_state, q_val):
        action_list = ['up', 'down', 'left', 'right']
        self.player_move(next_state[0] - state[0], next_state[1] - state[1])
        self.update_val(state, action_list[action], round(q_val, 2))

    def attach(self):
        # Return True when the player reaches the goal, False otherwise
        return str(self.get_state(self.player)) == str(self.get_state(self.diamond))
The training script Q_test.py wires the agent and the environment together:

# -*- coding: utf-8 -*-
# @Time : 2020/1/8 14:35
# @Author : HelloWorld!
# @FileName: Q_test.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.6
import numpy as np
import Q_learing


class Agent:
    def __init__(self):
        self.actions = [0, 1, 2, 3]  # up, down, left, right
        self.q_table = dict()
        self.__init_q_table()
        self.epsilon = 0.1
        self.learning_rate = 0.1
        self.gamma = 0.8
        # print(self.q_table)

    def __init_q_table(self):
        for v in range(5):
            for h in range(5):
                self.q_table[str([h, v])] = [0.0, 0.0, 0.0, 0.0]

    def get_action(self, state):
        # Choose the next action for this state, skipping actions whose
        # Q-value has gone negative (i.e. moves known to be blocked)
        action_list = self.q_table[str(state)]
        pass_action_index = []
        for index, val in enumerate(action_list):
            if val >= 0:
                pass_action_index.append(index)
        # Epsilon-greedy action selection
        if np.random.rand() <= self.epsilon:
            # Explore: pick a random feasible action
            return np.random.choice(pass_action_index)
        else:
            # Exploit: pick the action with the largest Q-value
            max_val = action_list[pass_action_index[0]]
            max_list = []
            for i in pass_action_index:
                # If several actions tie for the maximum, choose one of them at random
                if max_val < action_list[i]:
                    max_list.clear()
                    max_val = action_list[i]
                    max_list.append(i)
                elif max_val == action_list[i]:
                    max_list.append(i)
            return np.random.choice(max_list)

    def update_q_table(self, feedback, state, action, reward, next_state):
        # Q(s,a) = Q(s,a) + lr * { reward + gamma * max[Q(s',a')] - Q(s,a) }
        q_s_a = self.q_table[str(state)][action]  # current Q-value for (state, action)
        if feedback == 'stop':
            q_ns_a = 0  # hitting a wall has no next state; the agent stays put
        else:
            q_ns_a = np.max(self.q_table[str(next_state)])
        # Bellman update
        # self.q_table[str(state)][action] = q_s_a + self.learning_rate * (
        #     reward + self.gamma * q_ns_a - q_s_a
        # )
        self.q_table[str(state)][action] = (1 - self.learning_rate) * q_s_a + self.learning_rate * (
                reward + self.gamma * q_ns_a)
        # print(self.q_table)
        return self.q_table[str(state)][action]


if __name__ == '__main__':
    np.random.seed(0)
    env = Q_learing.GridWorld()
    agent = Agent()
    for ep in range(2000):
        if ep < 100:
            agent.epsilon = 0.2
        else:
            agent.epsilon = 0.1
        state = env.reset()
        print('Episode {} starting ... '.format(ep + 1))
        while not env.attach():
            action = agent.get_action(state)  # choose an action
            # print(action)
            feedback, next_state, reward = env.exec_calc(action)  # compute the transition
            q_val = agent.update_q_table(feedback, state, action, reward, next_state)  # update the Q-table
            if feedback == 'stop':
                env.update_view(state, action, state, q_val)
                continue
            elif feedback == 'dead':
                env.update_view(state, action, next_state, q_val)
                break
            else:
                env.update_view(state, action, next_state, q_val)
                state = next_state  # move to the new state
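A note on the update in update_q_table: the commented-out incremental form Q(s,a) += lr * (reward + gamma * max Q(s',a') - Q(s,a)) and the (1 - lr) * Q(s,a) + lr * (reward + gamma * max Q(s',a')) form actually used are algebraically identical; the second just groups the terms differently. Treating a wall hit ('stop') as q_ns_a = 0 with reward -1 steadily drives that entry negative, which is exactly what lets get_action filter out wall-bumping moves with its val >= 0 check.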