强化学习笔记(二)---- 策略迭代算法

强化学习有两种常见迭代训练算法:策略迭代算法和值迭代算法。本文中主要讲述策略迭代算法。

先从一个简答的问题开始,下图为一个四方格子,每个位置的状态空间分别为{1, 2, 3, 4}, 其中 3 的位置是个陷阱, 4的位置有个金币。有一个机器人从状态1的位置开始寻找金币。落入陷阱的回报为-1,找到金币的回报为1,在其他位置间移动回报为0,可选的动作空间为{上,下,左,右}, 通过这个简单的问题,来学习强化学习的学习原理。

这里写图片描述

强化学习的学习过程,个人理解就是通过不断的尝试,去更新每个状态的值函数(每个状态的值代表了当前状态的优劣,如果状态值很大,从其他状态选择一个动作,转移到该状态便是一个正确的选择),然后通过更新后的值函数去动态的调整策略,在调整策略后,又去更新值函数,不断的迭代更新,最后训练完成一个满足要求的策略。在这个过程中,抽象出两个主要的过程,第一个叫策略评估,第二个叫策略改善。

针对上面给出的简单问题,先说明一些简单的概念:

每个状态的值函数:

代表机器人处于该状态时的优劣值。

针对问题的当前策略:

代表机器人处于某状态时,选择的下一步动作。对于选择的下一步动作,可以是确定式的,比如当机器人处于1位置的时候,确定的只选择往右走。也可以是概率式的,可以0.5的概率选择往右走, 0.5的概率选择往下走。当然确定式策略选择是概率式的策略选择的一种特例。下文中采用确定式策略进行描述

策略评估:

策略评估就是通过某种方式,计算状态空间中每个状态的值函数。由于状态空间之间存在很多转移关系,要直接计算某个状态的值函数,是很困难的,一般采用
迭代方法。

策略改善:

对策略的改善,即通过当前拥有的信息,对当前策略进行优化,修改当前策略。

############################## 策略评估的过程

初始化的策略和值函数。

这里写图片描述

对于这个简单的例子,通过一步计算便得到了稳定的值函数,但是对于大多数的问题,都需要通过多步的迭代,才能得到稳定的值函数。

############################## 策略改善的过程

对于这个简单的例子,采用贪心的方式对策略进行改善,通过上一步策略评估过程计算出的稳定的值函数,让每个状态在选择下一步动作的时候,选择使动作收益最大的动作。
这里写图片描述

总结

强化学习策略迭代算法的过程就是不断的重复 策略评估 和 策略改善的过程,直到整个策略收敛(值函数和策略不再发生大的变化)

gym的样例代码展示

上述的简单示例,简单描述了强化学习的策略迭代算法的过程,下面将问题搞复杂一点,通过对该问题进行编程,加深对策略迭代算法的理解。新问题的空间状态如下图所示。

这里写图片描述

该图的状态空间由下置上,从左到右,分别为1 – 36

import numpy as np
import random
from gym import spaces
import gym
from gym.envs.classic_control import rendering

#模拟环境类
class GridWorldEnv(gym.Env):
    #相关的全局配置
    metadata = {
        'render.modes':['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = [i for i in range(1, 37)] #初始化状态
        self.terminate_states = [3, 7, 11, 15, 19, 20, 23, 30,  33, 34] #终结态
        self.actions = ['up', 'down', 'left', 'right'] #动作空间

        self.v_states = dict() #状态的值空间
        for state in self.states:
            self.v_states[state] = 0.0

        for state in self.terminate_states: #先将所有陷阱和黄金的值函数初始化为-1.0
            self.v_states[state] = -1.0

        self.v_states[34] = 1.0  #黄金的位置值函数初始化为 1

        self.initStateAction() #初始化每个状态的可行动作空间
        self.initStatePolicyAction() #随机初始化当前策略

        self.gamma = 0.8 #计算值函数用的折扣因子
        self.viewer = None #视图对象
        self.current_state = None #当前状态
        return

    def translateStateToRowCol(self, state):
        """
        将状态转化为行列座标返回
        """
        row = (state - 1) // 6
        col = (state - 1) %  6
        return row, col

    def translateRowColToState(self, row, col):
        """
        将行列座标转化为状态值
        """
        return row * 6 + col + 1

    def actionRowCol(self, row, col, action):
        """
        对行列座标执行动作action并返回座标
        """
        if action == "up":
            row = row - 1
        if action == "down":
            row = row + 1
        if action == "left":
            col = col - 1
        if action == "right":
            col = col + 1
        return row, col

    def canUp(self, row, col):
        row = row - 1
        return 0 <= row <= 5

    def canDown(self, row, col):
        row = row + 1
        return 0 <= row <= 5

    def canLeft(self, row, col):
        col = col - 1
        return 0 <= col <= 5

    def canRight(self, row, col):
        col = col + 1
        return 0 <= col <= 5

    def initStateAction(self):
        """
        初始化每个状态可行动作空间
        """
        self.states_actions = dict()
        for state in self.states:
            self.states_actions[state] = []
            if state in self.terminate_states:
                continue
            row, col = self.translateStateToRowCol(state)
            if self.canUp(row, col):
                self.states_actions[state].append("up")
            if self.canDown(row, col):
                self.states_actions[state].append("down")
            if self.canLeft(row, col):
                self.states_actions[state].append('left')
            if self.canRight(row, col):
                self.states_actions[state].append('right')
        return


    def initStatePolicyAction(self):
        """
        初始化每个状态的当前策略动作
        """
        self.states_policy_action = dict()
        for state in self.states:
            if state in self.terminate_states:
                self.states_policy_action[state] = None
            else:
                self.states_policy_action[state] = random.sample(self.states_actions[state], 1)[0]
        return


    def seed(self, seed = None):
        random.seed(seed)
        return [seed]

    def reset(self):
        """
        重置原始状态
        """
        self.current_state = random.sample(self.states, 1)[0]

    def step(self, action):
        """
        动作迭代函数
        """
        cur_state = self.current_state
        if cur_state in self.terminate_states:
            return cur_state, 0, True, {}
        row, col = self.translateStateToRowCol(cur_state)
        n_row, n_col = self.actionRowCol(row, col, action)
        next_state = self.translateRowColToState(n_row, n_col)
        self.current_state = next_state
        if next_state in self.terminate_states:
            return next_state, 0, True, {}
        else:
            return next_state, 0, False, {}

    def policy_evaluate(self):
        """
        策略评估过程 
        """
        error = 0.000001 #误差率
        for _ in range(1000):
            max_error = 0.0 #初始化最大误差
            for state in self.states:
                if state in self.terminate_states:
                    continue
                action = self.states_policy_action[state]
                self.current_state = state
                next_state, reward, isTerminate, info = self.step(action)
                old_value = self.v_states[state]
                self.v_states[state] = reward + self.gamma * self.v_states[next_state]
                abs_error = abs(self.v_states[state] - old_value)
                max_error = abs_error if abs_error > max_error else max_error #更新最大值
            if max_error < error:
                break


    def policy_improve(self):
        """
        根据策略评估的结果,进行策略更新,并返回每个状态的当前策略是否发生了变化
        """
        changed = False
        for state in self.states:
            if state in self.terminate_states:
                continue
            max_value_action = self.states_actions[state][0] #当前最大值行为
            max_value = -1000000000000.0 #当前最大回报 
            for action in self.states_actions[state]:
                self.current_state = state
                next_state, reward, isTerminate, info = self.step(action)
                q_reward = reward + self.gamma * self.v_states[next_state]
                if q_reward > max_value:
                    max_value_action = action
                    max_value = q_reward
            if self.states_policy_action[state] != max_value_action:
                changed = True
            self.states_policy_action[state] = max_value_action
        return changed





    def createGrids(self):
        """
        创建网格
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.states:
            row, col = self.translateStateToRowCol(state)
            x = start_x + col * line_length
            y = start_y + row * line_length
            line = rendering.Line((x, y), (x + line_length, y))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y), (x, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x + line_length, y), (x + line_length, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y + line_length), (x + line_length, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)

    def createTraps(self):
        """
        创建陷阱,将黄金的位置也先绘制成陷阱,后面覆盖画成黄金
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        for state in self.terminate_states:
            row, col = self.translateStateToRowCol(state)
            trap = rendering.make_circle(20)
            trans = rendering.Transform()
            trap.add_attr(trans)
            trap.set_color(0, 0, 0)
            trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
            self.viewer.add_onetime(trap)

    def createGold(self):
        """
        创建黄金
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        state = 34
        row, col = self.translateStateToRowCol(state)
        gold = rendering.make_circle(20)
        trans = rendering.Transform()
        gold.add_attr(trans)
        gold.set_color(1, 0.9, 0)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(gold)

    def createRobot(self):
        """
        创建机器人
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        row, col = self.translateStateToRowCol(self.current_state)
        robot = rendering.make_circle(15)
        trans = rendering.Transform()
        robot.add_attr(trans)
        robot.set_color(1, 0, 1)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(robot)

    def render(self, mode="human", close=False):
        """
        渲染整个场景
        """
        #关闭视图
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None

        #视图的大小
        screen_width = 320
        screen_height = 320


        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)

        #创建网格
        self.createGrids()
        #创建陷阱
        self.createTraps()
        #创建黄金
        self.createGold()
        #创建机器人
        self.createRobot()
        return self.viewer.render(return_rgb_array= mode == 'rgb_array')
注册环境模拟类到gym
from gym.envs.registration import register
try:
    register(id = "GridWorld-v3", entry_point=GridWorldEnv, max_episode_steps = 200, reward_threshold=100.0)
except:
    pass
进行策略迭代算法的过程和模拟动画的代码
from time import sleep
env = gym.make('GridWorld-v3')
env.reset()

#策略评估和策略改善 
not_changed_count = 0
for _ in range(10000):
    env.env.policy_evaluate()
    changed = env.env.policy_improve()
    if changed:
        not_changed_count = 0
    else:
        not_changed_count += 1
    if not_changed_count == 10: #超过10次策略没有再更新,说明策略已经稳定了
        break


#观察env到底是个什么东西的打印信息。
print(isinstance(env, GridWorldEnv))
print(type(env))
print(env.__dict__)
print(isinstance(env.env, GridWorldEnv))

env.reset()

for _ in range(1000):
    env.render()
    if env.env.states_policy_action[env.env.current_state] is not None:
        observation,reward,done,info = env.step(env.env.states_policy_action[env.env.current_state])
    else:
        done = True
    print(_)
    if done:
        sleep(0.5)
        env.render()
        env.reset()
        print("reset")
    sleep(0.5)
env.close()
动画效果

这里写图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章