Reinforcement Learning: Q-learning Algorithm Principles and Implementation

The idea behind Q-learning

The first Spring Festival of the 2020s is almost here, so let me wish everyone an early Happy New Year. For someone like me who no longer gets a winter vacation, the anticipation is not what it used to be, but there is still something to look forward to: a little bit of a year-end bonus. Over a year of work there has been hard effort, hesitation, and slacking off, and all of it shows up at year end. Through diligence and through laziness I have passed through many different states this year, and after this whole sequence of state transitions I naturally hope the year ends well. If we discretize the year into time steps, then at each step I take an action, transition to a new state, and receive a corresponding value. What I want is for the total value accumulated over all these transitions to be as large as possible. That, in essence, is Q-learning in reinforcement learning.

Let $S$ be the set of states and $A$ the set of actions. Taking action $a_t$ in state $s_t$ moves the agent to a new state $s_{t+1}$ and yields a reward (or penalty) $r_{t+1}$. For example, suppose at some point in the year I am in the slacking-off state and then take the action of working overtime; the boss notices and gives me a raise. Or, while in the hard-working state, my mood suddenly turns sour, I start slacking off, the boss notices and cuts my pay. That is what the state-action transitions in Q-learning mean. Naturally, I want to maximize what I gain over the whole year, and the year's total gain can be written as $\sum_{t=1}^{T} r_t$.

When the agent is in some state, Bellman's principle of optimality applies: an optimal path from the start point $(i_0, j_0)$ to the destination $(i_f, j_f)$ that passes through $(i, j)$ is the concatenation of an optimal path from the start to node $(i, j)$ and an optimal path from $(i, j)$ to the destination. The state-value update can therefore be written recursively:
$Q_{\text{next}} = Q_{\text{now}}(s_t, a_t) + b\beta$, where $\beta = r_{t+1} + \gamma \max_{a} Q(s_{t+1}, a)$.
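As a side note, here is a minimal sketch (not from the original post) of this one-step update in code, with the step size written as alpha and the Q table stored as a NumPy array of shape (num_states, num_actions). This is the common form that also subtracts out the old estimate; with alpha = 1 it collapses to Q(s, a) = R(s, a) + γ · max Q(s', ·), which is exactly the rule used in the room example below.

import numpy as np

def q_update(Q, s, a, r, s_next, alpha=0.1, gamma=0.8):
    # beta is the learning target: the immediate reward plus the discounted best value of the next state
    beta = r + gamma * Q[s_next].max()
    # blend the old estimate with the target; alpha = 1 simply overwrites Q(s, a) with beta
    Q[s, a] = (1 - alpha) * Q[s, a] + alpha * beta
    return Q[s, a]

Q = np.zeros((6, 6))
print(q_update(Q, s=1, a=5, r=100, s_next=5, alpha=1.0))  # -> 100.0, the first update of Q(1, 5) in the room example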
Below we use a concrete example to show how the expression for $\beta$ is implemented.

[Figure: floor plan of the rooms, numbered 0 to 5, with 5 as the final destination]
Consider a set of rooms as in the figure above, where 5 is the final destination. Represented as a graph, it looks like this:
[Figure: the rooms and their connections drawn as a graph, with node 5 as the goal]
The reward matrix R for transitions between the rooms is:
Action:      0    1    2    3    4    5
State 0:    -1   -1   -1   -1    0   -1
State 1:    -1   -1   -1    0   -1  100
State 2:    -1   -1   -1    0   -1   -1
State 3:    -1    0    0   -1    0   -1
State 4:     0   -1   -1    0   -1  100
State 5:    -1    0   -1   -1    0  100

(-1 marks a transition that is not possible; 100 is the reward for reaching room 5.)
The value function update can then be written as (1.1):
Q(state, action) = R(state, action) + Gamma * Max[Q(next state, all actions)]
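As a quick sanity check with Gamma = 0.8: starting from an all-zero Q table, the first time the agent moves from room 1 to room 5 the update gives Q(1, 5) = R(1, 5) + 0.8 × max(Q(5, 1), Q(5, 4), Q(5, 5)) = 100 + 0.8 × 0 = 100; later episodes then propagate this value backwards to the other rooms.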

Python implementation of the Q-learning algorithm

# -*- coding: utf-8 -*-
# @Time    : 2020/1/10 9:35
# @Author  : HelloWorld!
# @FileName: q_room.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.6

import numpy as np
import random

# Initialize the Q table as a 6x6 matrix of zeros
Q = np.zeros((6, 6))
Q = np.matrix(Q)

# Reward matrix R (-1 marks impossible transitions, 100 is the reward for reaching room 5)
R = np.matrix([[-1, -1, -1, -1, 0, -1], [-1, -1, -1, 0, -1, 100], [-1, -1, -1, 0, -1, -1], [-1, 0, 0, -1, 0, -1],
               [0, -1, -1, 0, -1, 100], [-1, 0, -1, -1, 0, 100]])

# Learning parameter: discount factor
gamma = 0.8

# Training
for i in range(2000):

    # For each episode, start from a randomly chosen state
    state = random.randint(0, 5)

    while True:

        # Collect every action that is allowed from the current state (reward >= 0)
        r_pos_action = []
        for action in range(6):
            if R[state, action] >= 0:
                r_pos_action.append(action)

        # Pick one of the allowed actions at random; in this problem the next state equals the chosen action
        next_state = r_pos_action[random.randint(0, len(r_pos_action) - 1)]

        # Update: Q(s, a) = R(s, a) + gamma * max Q(s', a')
        Q[state, next_state] = R[state, next_state] + gamma * (Q[next_state]).max()

        state = next_state
        if i % 100 == 0:
            print('---------------------', i)
            print(Q)

        # State 5 is the goal; reaching it ends the episode
        if state == 5:
            break

print(Q)
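
After training, the learned table can be read greedily to recover a route to room 5. Below is a minimal sketch (not part of the original code), assuming the loop above has been run to convergence; the starting room 2 is an arbitrary choice.

# Follow the greedy policy: from each room, move to the neighbour with the largest Q value
start = 2
state, path = start, [start]
for _ in range(10):  # step cap, in case the table has not fully converged
    state = int(np.argmax(Q[state]))
    path.append(state)
    if state == 5:  # room 5 is the destination
        break
print('Greedy path from room {}: {}'.format(start, path))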

Conclusion

Q-learning is the most basic reinforcement learning algorithm. When the set of actions is finite but the number of states is very large, a deep neural network can be used to approximate the value function. So as long as the states, actions, and value function of a problem are mapped out properly, reinforcement learning can be applied to solve it.
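As a rough illustration of that last point (this sketch is not from the original post, and all sizes and weight names are made up), the table lookup can be replaced by a small network that maps a state feature vector to one Q value per action:

import numpy as np

n_features, n_actions, n_hidden = 4, 4, 16  # illustrative sizes only
rng = np.random.default_rng(0)
W1 = rng.normal(scale=0.1, size=(n_features, n_hidden))
W2 = rng.normal(scale=0.1, size=(n_hidden, n_actions))

def q_values(state_vec):
    # forward pass: state features -> hidden layer (ReLU) -> one Q estimate per action
    h = np.maximum(0.0, state_vec @ W1)
    return h @ W2

print(q_values(np.array([0.0, 1.0, 0.0, 0.5])))  # four Q values for this (made-up) state vector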
Attached below is an example of choosing a path from a starting point to a destination through random exploration. Reference: https://blog.csdn.net/shankezh/article/details/102864085

You will need three PNG images named boom.png, diamond.png, and player.png, placed in the same directory as Q_learing.py and Q_test.py.

# -*- coding: utf-8 -*-
# @Time    : 2020/1/8 14:34
# @Author  : HelloWorld!
# @FileName: Q_learing.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.6


import tkinter as tk

from PIL import ImageTk

from PIL import Image

import time


class Env:

    def __init__(self):

        self.grid_size = 100

        self.win = tk.Tk()

        self.pic_player, self.pic_diamond, self.pic_boom1, self.pic_boom2, self.pic_boom3, self.pic_boom4 = self.__load_img()

        self.__init_win()

        self.canvas = self.__init_rc()

        self.texts = self.__produce_text()

        self.canvas.pack()

        # self._init_test_case()

        # self.win.mainloop()

    def __init_win(self):

        self.win.title('Grid World')

        # self.win.geometry("500x300")

    def __init_rc(self):

        canvas = tk.Canvas(self.win, width=500, height=720, bg='white')

        for h in range(5):

            for v in range(5):
                canvas.create_rectangle(self.grid_size * v, self.grid_size * h, self.grid_size * (v + 1),
                                        self.grid_size * (h + 1))

        trans_pixel = int(self.grid_size / 2)

        self.player = canvas.create_image(trans_pixel + self.grid_size * 0, trans_pixel + self.grid_size * 0,
                                          image=self.pic_player)

        self.diamond = canvas.create_image(trans_pixel + self.grid_size * 4, trans_pixel + self.grid_size * 4,
                                           image=self.pic_diamond)

        self.boom1 = canvas.create_image(trans_pixel + self.grid_size * 1, trans_pixel + self.grid_size * 1,
                                         image=self.pic_boom1)

        self.boom2 = canvas.create_image(trans_pixel + self.grid_size * 3, trans_pixel + self.grid_size * 1,
                                         image=self.pic_boom2)

        self.boom3 = canvas.create_image(trans_pixel + self.grid_size * 1, trans_pixel + self.grid_size * 3,
                                         image=self.pic_boom3)

        self.boom4 = canvas.create_image(trans_pixel + self.grid_size * 3, trans_pixel + self.grid_size * 3,
                                         image=self.pic_boom4)

        return canvas

    def __load_img(self):

        pic_resize = int(self.grid_size / 2)

        player = ImageTk.PhotoImage(Image.open("player.png").resize((pic_resize, pic_resize)))

        diamond = ImageTk.PhotoImage(Image.open("diamond.png").resize((pic_resize, pic_resize)))

        boom1 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))

        boom2 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))

        boom3 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))

        boom4 = ImageTk.PhotoImage(Image.open('boom.png').resize((pic_resize, pic_resize)))

        return player, diamond, boom1, boom2, boom3, boom4

    def __produce_text(self):

        texts = []

        x = self.grid_size / 2

        y = self.grid_size / 6

        for h in range(5):

            for v in range(5):
                up = self.canvas.create_text(x + h * self.grid_size, y + v * self.grid_size, text=0)

                down = self.canvas.create_text(x + h * self.grid_size, self.grid_size - y + v * self.grid_size, text=0)

                left = self.canvas.create_text(y + h * self.grid_size, x + v * self.grid_size, text=0)

                right = self.canvas.create_text(self.grid_size - y + h * self.grid_size, x + v * self.grid_size, text=0)

                texts.append({"up": up, "down": down, "left": left, "right": right})

        return texts

    def _win_d_update(self):

        self.win.update()

        time.sleep(0.1)


class GridWorld(Env):

    def __init__(self):

        super().__init__()

        self._win_d_update()

    def player_move(self, x, y):

        # x moves the player horizontally (positive = right), y vertically (positive = down)

        self.canvas.move(self.player, x * self.grid_size, y * self.grid_size)

        self._win_d_update()

    def reset(self):

        # Reset the player to the starting position

        x, y = self.canvas.coords(self.player)

        self.canvas.move(self.player, -x + self.grid_size / 2, -y + self.grid_size / 2)

        self._win_d_update()

        return self.get_state(self.player)

    def get_state(self, who):

        x, y = self.canvas.coords(who)

        state = [int(x / self.grid_size), int(y / self.grid_size)]

        return state

    def update_val(self, num, arrow, val):

        pos = num[0] * 5 + num[1]

        x, y = self.canvas.coords(self.texts[pos][arrow])

        self.canvas.delete(self.texts[pos][arrow])

        self.texts[pos][arrow] = self.canvas.create_text(x, y, text=val)

        # self._win_d_update()

    def exec_calc(self, action):

        # Execute one decision step

        feedback = 'alive'  # alive = passable move, stop = hit the boundary, dead = hit a bomb

        next_state = []

        next_h, next_v, reward = 0.0, 0.0, 0.0

        h, v = self.get_state(self.player)

        if action == 0:  # up

            next_h = h

            next_v = v - 1

            # self.player_move(0, -1)

        elif action == 1:  # down

            next_h = h

            next_v = v + 1

            # self.player_move(0, 1)

        elif action == 2:  # left

            next_h = h - 1

            next_v = v

            # self.player_move(-1, 0)

        elif action == 3:  # right

            next_h = h + 1

            next_v = v

            # self.player_move(1, 0)

        else:

            print('programmer bug ...')

        next_state = [next_h, next_v]

        boom1, boom2, boom3, boom4 = (self.get_state(self.boom1), self.get_state(self.boom2),
                                      self.get_state(self.boom3), self.get_state(self.boom4))

        diamond = self.get_state(self.diamond)

        if next_h < 0 or next_v < 0 or next_h > 4 or next_v > 4:  # out of bounds

            reward = -1

            feedback = 'stop'

        elif next_state == boom1 or next_state == boom2 or next_state == boom3 or next_state == boom4:  # bomb cell

            reward = -100

            feedback = 'dead'

        elif next_state == diamond:  # reached the goal item (the diamond)

            reward = 500

        else:

            reward = 0

        return feedback, next_state, reward

    def update_view(self, state, action, next_state, q_val):

        action_list = ['up', 'down', 'left', 'right']

        self.player_move(next_state[0] - state[0], next_state[1] - state[1])

        self.update_val(state, action_list[action], round(q_val, 2))

    def attach(self):

        # Return True if the player has reached the goal, otherwise False

        return str(self.get_state(self.player)) == str(self.get_state(self.diamond))
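
The Q_learing.py file above only defines the environment: the Tk window, the 5x5 grid, the player, the four bombs, and the diamond. The agent itself lives in Q_test.py below; it keeps a Q table as a dict keyed by the state, chooses actions with an epsilon-greedy policy, and updates the table with the learning-rate form of the Q-learning update.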

# -*- coding: utf-8 -*-
# @Time    : 2020/1/8 14:35
# @Author  : HelloWorld!
# @FileName: Q_test.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.6


import numpy as np

import Q_learing


class Agent:

    def __init__(self):

        self.actions = [0, 1, 2, 3]  # up down left right

        self.q_table = dict()

        self.__init_q_table()

        self.epsilon = 0.1

        self.learning_rate = 0.1

        self.gamma = 0.8

        # print(self.q_table)

    def __init_q_table(self):

        for v in range(5):

            for h in range(5):
                self.q_table[str([h, v])] = [0.0, 0.0, 0.0, 0.0]

    def get_action(self, state):

        # Choose the next action for this state, skipping actions whose Q value marks them as impassable

        action_list = self.q_table[str(state)]

        pass_action_index = []

        for index, val in enumerate(action_list):

            if val >= 0:
                pass_action_index.append(index)

        # Epsilon-greedy action selection

        if np.random.rand() <= self.epsilon:

            # Explore: pick a random allowed action

            return np.random.choice(pass_action_index)

        else:

            # Exploit: pick the action with the largest Q value

            max_val = action_list[pass_action_index[0]]

            max_list = []

            for i in pass_action_index:

                # If several actions tie for the maximum, pick one of them at random

                if max_val < action_list[i]:

                    max_list.clear()

                    max_val = action_list[i]

                    max_list.append(i)

                elif max_val == action_list[i]:

                    max_list.append(i)

            return np.random.choice(max_list)

    def update_q_table(self, feedback, state, action, reward, next_state):

        # Q(s,a) = Q(s,a) + lr * { reward + gamma * max[Q(s`,a`)] - Q(s,a) }

        q_s_a = self.q_table[str(state)][action]  # current Q value for this state-action pair

        if feedback == 'stop':

            q_ns_a = 0  # hitting a wall keeps the player in place, so there is no next-state value

        else:

            q_ns_a = np.max(self.q_table[str(next_state)])

        # Bellman equation update

        # self.q_table[str(state)][action] = q_s_a + self.learning_rate * (

        #     reward + self.gamma * q_ns_a - q_s_a

        # )

        self.q_table[str(state)][action] = (1 - self.learning_rate) * q_s_a + self.learning_rate * (
                    reward + self.gamma * q_ns_a)

        # print(self.q_table)

        return self.q_table[str(state)][action]


if __name__ == '__main__':

    np.random.seed(0)

    env = Q_learing.GridWorld()

    agent = Agent()

    for ep in range(2000):

        if ep < 100:

            agent.epsilon = 0.2

        else:

            agent.epsilon = 0.1

        state = env.reset()

        print('Episode {} starting ... '.format(ep + 1))

        while not env.attach():

            action = agent.get_action(state)  # choose an action

            # print(action)

            feedback, next_state, reward = env.exec_calc(action)  # compute feedback, next state, and reward

            q_val = agent.update_q_table(feedback, state, action, reward, next_state)  # update the Q table

            if feedback == 'stop':

                env.update_view(state, action, state, q_val)

                continue

            elif feedback == 'dead':

                env.update_view(state, action, next_state, q_val)

                break

            else:

                env.update_view(state, action, next_state, q_val)

            state = next_state  # move to the new state
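
To run this demo, put Q_learing.py, Q_test.py, and the three PNG images in the same folder and run python Q_test.py. A Tk window opens, the player icon moves around the grid, and the four per-direction Q values drawn in each cell are updated as training proceeds.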
