【Python】Q-Learning處理CartPole-v1

上一篇配置成功gym環境後,就可以利用該環境做強化學習仿真了。

這裏首先用之前學習過的qlearning來處理CartPole-v1模型。

CartPole-v1是一個倒立擺模型,目標是通過左右移動滑塊保證倒立杆能夠儘可能長時間倒立,最長步驟爲500步。

模型控制量是左0、右1兩個。

模型狀態量爲下面四個:

Num Observation Min Max
0 Cart Position -4.8 4.8
1 Cart Velocity -Inf Inf
2 Pole Angle -0.418rad 0.418rad
3 Pole Angular Velocity -Inf Inf

由於要用qtable,但是狀態量是連續的,所以我們要先對狀態做離散化處理,對應state_dig函數。

然後就是按照qlearning公式迭代即可。

這裏在選控制量的時候用了ε-greedy策略,即根據迭代次數,逐步更相信模型的結果而不是隨機的結果。

qlearning走迷宮當時的策略是有10%的概率用隨機的控制量,ε-greedy策略相對更合理一些。

代碼如下:

import gym
import random
import numpy as np

Num = 10
rate = 0.5
factor = 0.9

p_bound = np.linspace(-2.4,2.4,Num-1)
v_bound = np.linspace(-3,3,Num-1)
ang_bound = np.linspace(-0.5,0.5,Num-1)
angv_bound = np.linspace(-2.0,2.0,Num-1)

def state_dig(state):                   #離散化
    p,v,ang,angv = state
    digital_state = (np.digitize(p, p_bound),
            np.digitize(v, v_bound),
            np.digitize(ang, ang_bound), 
            np.digitize(angv, angv_bound))
    return digital_state

if __name__ == '__main__':

    env = gym.make('CartPole-v1')

    action_space_dim = env.action_space.n  
    q_table = np.zeros((Num,Num,Num,Num, action_space_dim))

    for i in range(3000):
        state = env.reset()
        digital_state = state_dig(state)
                
        step = 0
        while True:
            if i%10==0:
                env.render()
            
            step +=1
            epsi = 1.0 / (i + 1)
            if random.random() < epsi:
                action = random.randrange(action_space_dim)
            else:
                action = np.argmax(q_table[digital_state])

            next_state, reward, done, _ = env.step(action)
            next_digital_state = state_dig(next_state)
  
            if done: 
                if step < 400:
                    reward = -1  
                else:   
                    reward = 1
            else:
                reward = 0

            current_q = q_table[digital_state][action]      #根據公式更新qtable
            q_table[digital_state][action] += rate * (reward + factor * max(q_table[next_digital_state])  - current_q) 

            digital_state = next_digital_state

            if done:
                print(step)
                break

最終結果基本都能維持到500步左右,不過即使到500後,隨着模型迭代,狀態也可能不穩定。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章