【Python】強化學習Q-Learning走迷宮

Q-Learning是一種基於值函數的強化學習算法,這裏用該算法解決走迷宮問題。

算法步驟如下:

1. 初始化 Q 表:每個表格對應狀態動作的 Q 值。這裏就是一個H*W*4的表,4代表上下左右四個動作。

2. 選擇動作: 根據 Q 表格選擇最優動作或者以一定概率隨機選擇動作。

3. 執行動作,得到返回獎勵(這裏需要自定義,比如到達目標給的大的reward,撞牆給個小的reward)和下一個狀態。

4. 更新 Q 表: 根據規則更新 Q 表格中對應狀態動作的 Q 值。規則爲 Q(s, a) = Q(s, a) + α*[r + γ*max(Q(s', a')) - Q(s, a)],其中 α 是學習率,γ 是折扣因子,r 是獲得的獎勵,s 是當前狀態,a 是當前動作,s' 是下一個狀態,a' 是在下一個狀態下選擇的最優動作。

5. 重複步驟 2-4: 不斷與環境交互,選擇動作、執行、更新 Q 值,直至滿足停止條件(如達到最大迭代次數或者 Q 值收斂等)。

6. 最優策略提取: 通過學習得到的 Q 表格,可以提取最優策略,即在每個狀態下選擇具有最高 Q 值的動作。

代碼如下:

import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import imageio
import io

H = 30
W = 40

start = (0, random.randint(0, H-1))
goal = (W-1, random.randint(0, H-1))

img = Image.new('RGB', (W, H), (255, 255, 255))
pixels = img.load()

maze = np.zeros((W, H))
for h in range(H):
    for w in range(W):
        if random.random() < 0.1:
            maze[w, h] = -1

actions_num = 4
actions = [0, 1, 2, 3]
q_table = np.zeros((W, H, actions_num))
rate = 0.5
factor = 0.9
images = []

for i in range(2000):

    state = start
    path = [start]
    while(True):

        if np.random.rand() < 0.1:              #隨機或者下一個狀態最大q值對應的動作
            action = np.random.choice(actions)
        else:
            action = np.argmax(q_table[state])

        next_state = None                       #執行該動作
        if action == 0 and state[0] > 0:
            next_state = (state[0]-1, state[1])
        elif action == 1 and state[0] < W-1:
            next_state = (state[0]+1, state[1])
        elif action == 2 and state[1] > 0:
            next_state = (state[0], state[1]-1)
        elif action == 3 and state[1] < H-1:
            next_state = (state[0], state[1]+1)
        else:
            next_state = state

        if next_state == goal:                  #得到reward,到目標給大正反饋
            reward = 100
        elif maze[next_state] == -1:
            reward = -100                       #遇見障礙物給大負反饋
        else:
            reward = -1                         #走一步給小負反饋,走的步數越小,負反饋越小

        done = (state == goal)  

        if done:
            break

        current_q = q_table[state][action]      #根據公式更新qtable
        q_table[state][action] += rate * (reward + factor * max(q_table[next_state])  - current_q) 

        state = next_state
        path.append(state)

    if i % 100 == 0:                            #每100次看結果

        for h in range(H):
            for w in range(W):
                if maze[w,h]==-1:
                    pixels[w, h] = (0, 0, 0)
                else:
                    pixels[w, h] = (255, 255, 255)

        for x, y in path:
            pixels[x, y] = (0, 0, 255)

        pixels[start] = (255, 0, 0)
        pixels[goal] = (0, 255, 0)

        plt.clf()                           # 清除當前圖形
        plt.imshow(img)
        plt.pause(0.1)                      # 暫停0.1秒,顯示動態效果

        buf = io.BytesIO()
        plt.savefig(buf, format='png')      # 保存圖像到內存中
        buf.seek(0)                         # 將文件指針移動到文件開頭
        images.append(imageio.imread(buf))  # 從內存中讀取圖像並添加到列表中

plt.show()
imageio.mimsave('result.gif', images, fps=3)  # 保存爲 GIF 圖像,幀率爲3

結果如下:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章