强化学习—DQN:不讲前世,就论今生

{"type":"doc","content":[{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"前言","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 相信小可爱们点进这篇文章,要么是对强化学习(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Reinforcement learning-RL","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")有一定的了解,要么是想要了解强化学习的魅力所在,要么是了解了很多基础知识,但是不知道代码如何写。今天我就以最经典和基础的算法(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"DQN","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")带大家一探强化学习的强大基因,不讲前世(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"不讲解公式推导","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"),只讨论今生(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"通俗语言+图示讲解DQN算法流程,以及代码如何实现","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"预备知识","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 如果小可爱们只想了解DQN算法的流程,那么跟着我的步伐,一点一点看下去就可以。如果你想要使用代码实现算法并亲眼看到它的神奇之处,也可以在本文中找到答案。","attrs":{}}]},{"type":"blockquote","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"本文代码实现基于tensorflow1.8,想看懂代码的,建议熟悉tf的基础知识","attrs":{}}]}],"attrs":{}},{"type":"horizontalrule","attrs":{}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 下面进入正题:首先想想为什么叫它强化学习呢?强化,强化,再强化??? 哦!想起来了,好像和我们日常学习中的强化练习、《英语强化习题集》有点相似,考过研的小可爱们都知道张宇老师还有强化课程和相关习题册。哈哈,是不是想起来在高三或考研期间被各种强化学习资料支配的恐惧呢?不用担心,在本文中没有强化资料。我们再思考一下,小可爱们当初为了考取理想的分数,不断做题强化,是不是就是为了提高做题速度、逻辑思考能力和正确率,从而提高成绩?通过不断做题强化,我们学到了很多知识,提高了分数,那这些指标是不是我们的","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"收益","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"?在进行强化训练整个过程中,我们是不是一次次地和试题进行较量,拼个你死我活,接着试题答案要么给我们当头一击吃个0蛋、要么赏给我们颗糖吃吃,顿时心里美滋滋。试题答案给了我们反馈,同时我们在接收反馈之后,反思自己,找到自己的不足并纠正,以使在以后面对这些题时,可以主动避开错误的解法,尽可能拿更高的分数。慢慢地,我们找到了做题套路,解题能力得到了强化提高。","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","text":"强化学习基本思想理解 ","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 强化学习的基本思想也是一样的道理。下面是强化学习框架:","attrs":{}}]},{"type":"image","attrs":{"src":"https://static001.geekbang.org/infoq/a5/a5b03c73740d3e63e56822f147effb32.jpeg","alt":null,"title":"强化学习框架图","style":[{"key":"width","value":"100%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"现在把图中的","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Agent(智能体)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"当做小可爱自己,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Environment(环境)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"当做大量的试题集(包含大量的不同的试题),","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"State s(状态)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"当做当前时刻小可爱正在做的某一试题,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Action(动作)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"当做小可爱解题过程和步骤,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Reward(奖励)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"当做试题答案给小可爱的作答打分,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"State s'","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"当做在试题集中小可爱当前所做题的下一个试题【这道题小可爱做完了,也看到得了多少分,就会去做下一道试题,是不是就意味着状态从当前状态转移到了下一个状态】。小可爱们按照这种思路再次读一读上面一段话,对照着强化学习框架,是不是就明白了智能体与环境交互的过程,也明白了强化学习的基本思想了呢?","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","text":"强化学习中的探索和利用理解","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 如果小可爱们在读了上面的讲解之后,已经对强化学习基本思想有了进一步的理解,那么恭喜小可爱顺利通过第一关。如果还没有更加明白,不要慌!跟着我,在学习这一部分探索和利用问题的同时,你也可以进一步理解和巩固强化学习基本思想。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 在强化学习算法中,当智能体与环境进行交互时有一个探索的概念和一个利用的概念。顾名思义,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的意思就是智能体以前在某些状态没有做过这种动作【状态和动作是一一对应的,在某一状态下,做出某一动作】,它要尝试解锁一些新动作,说不定这个被解锁的动作就可以帮助智能体得到更高的分。","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"利用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的意思就是智能体从以前做过的所有动作中,选择一个可以帮助他获得更高分的动作。也就是利用以前的经历去选择出最优动作。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 下面以小可爱做","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"一道数学题","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"为例来说明这个问题,顺便加深对强化学习基本思想的理解。这是一道非常难的一道数学题,满分100分,有","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"上百个步骤","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",同时也有","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"很多种解法","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",这对于第一次接触这种题的小可爱来说,太难了。但是这并难不倒聪明的小可爱们,花了一天时间,大概写了100个步骤,终于把这道题做出来了。在对比了答案之后,事情不尽人意,只做对了几步,得了一点点步骤分。一看只得了这么一点分,小可爱们归纳整理做这道题的经验并反思。又开始重新做这道题,并试图在某些步骤尝试不同的解法【这是不是","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"呢?】,最终花费的时间比第一次少了许多,并且得分也比以前高了。这让小可爱们甚是兴奋,更加自信了。小可爱们试图又重新做这道题,基于前几次的做题经验【这里是不是包含了","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"利用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"呢?】以及","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"使用不同的方法,发现有时候在某些步骤上用比较复杂的方法,但是却为后面的解题大大减少了计算量,从而减小了错误概率。最后,尽管用时稍微长些,但是得分更高了。就这样一次又一次地重新做这道题,小可爱们慢慢地摸索到了这道题的最佳做题套路,在","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"用时和得分之间寻找到了平衡","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",使得最终所摸索到的关于这道题的做题技巧","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"性价比","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"最高。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 看完上面一段话,小可爱们是不是对于强化学习中的探索和利用以及基本思想有了更加深刻地理解了呢?恭喜你,小可爱,又完美通过一关。","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"图解DQN算法","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 这个部分,将使用图解和文字结合的方法对DQN算法进行剖析。先上个图:","attrs":{}}]},{"type":"image","attrs":{"src":"https://static001.geekbang.org/infoq/c1/c18cbca97b5e299a7892a1d66cb46262.jpeg","alt":null,"title":"DQN算法结构图","style":[{"key":"width","value":"100%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"如果看不懂上面的图,没关系。下面该部分将再次用图解的方式对DQN算法流程进行更细致的剖析,上图整活:","attrs":{}}]},{"type":"image","attrs":{"src":"https://static001.geekbang.org/infoq/75/7584eb830d8c94d6712f88a32823dedb.jpeg","alt":null,"title":"DQN算法深入剖析图","style":[{"key":"width","value":"100%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"高能预警!","attrs":{}}]},{"type":"blockquote","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"注意上图中的eval网络和target网络结构是完全一样的,但是权重参数不是实时同步的,target网络权重参数更新是落后于eval网络的权重参数更新的。","attrs":{}}]}],"attrs":{}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 从以上两个图中我们可以知道,DQN算法是分为两个部分的。","attrs":{}}]},{"type":"bulletedlist","content":[{"type":"listitem","attrs":{"listStyle":null},"content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"第一部分—经验收集阶段:智能体想要从幼儿园进化到大学水平,学习成绩更好那就必须要有经验啊。那这一部分就是智能体与环境进行交互收集经验的阶段。注意:这一部分只会用到eval网络,智能体从环境中观察到状态","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"输入到eval网络,网络会输出每一个动作对应的Q值。智能体在选择动作时一般使用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"ε贪心策略","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"来选择动作—【在整个训练初始阶段,大概率随机选择动作(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"),随着训练轮数的增加,小概率随机选择动作,大概率选择最大Q值对应的动作(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"利用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")】。在选择动作后,会作用到环境,环境会相应的反馈给智能体一个奖励","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"r","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",同时环境从当前状态","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"转移到","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s'","attrs":{}},{"type":"text","text":"。","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"每进行一次交互,都会把(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s,a,r,s'","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")保存到经验池中。","attrs":{}}]}]},{"type":"listitem","attrs":{"listStyle":null},"content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"第二部分—训练学习阶段:在交互次数达到预先设定的值后,就进入到了","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"训练学习","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"阶段。这个阶段是智能体真正利用以前的经验来进行反思提高的(训练)。智能体会从经验池中随机抽取","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"batch-size组","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的数据,把这组数据中的","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"批量输入到eval网络,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s'","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"批量输入到target网络中,得到","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"该组数据中每一条数据","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"对应的q_eval和q_target。然后计算","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"该组数据","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的损失,然后进行反向传播进行训练。训练结束后,再次进入到第一部分。","attrs":{}}]}]}],"attrs":{}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"以上两个部分交替进行,直到达到预先设定的训练轮数。相信各位小可爱在看了本部分的介绍,明白了DQN算法的具体流程,恭喜你,又一次成功通关。","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"代码实现—基于tensorflow1版本","attrs":{}}]},{"type":"codeblock","attrs":{"lang":"python"},"content":[{"type":"text","text":"\nimport numpy as np\nimport tensorflow as tf\n\nnp.random.seed(1)\ntf.set_random_seed(1)\n\n\n# Deep Q Network off-policy\nclass DeepQNetwork:\n def __init__(\n self,\n n_actions,\n n_features,\n learning_rate=0.01,\n reward_decay=0.9,\n e_greedy=0.9,\n replace_target_iter=300,\n memory_size=500,\n batch_size=32,\n e_greedy_increment=None,\n output_graph=False,\n ):\n self.n_actions = n_actions\n self.n_features = n_features\n self.lr = learning_rate\n self.gamma = reward_decay\n self.epsilon_max = e_greedy\n self.replace_target_iter = replace_target_iter\n self.memory_size = memory_size\n self.batch_size = batch_size\n self.epsilon_increment = e_greedy_increment\n self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max\n\n # total learning step\n self.learn_step_counter = 0\n\n # initialize zero memory [s, a, r, s_]\n self.memory = np.zeros((self.memory_size, n_features * 2 + 2))\n\n # consist of [target_net, evaluate_net]\n self._build_net()\n\n t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')\n e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')\n\n with tf.variable_scope('hard_replacement'):\n self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]\n\n self.sess = tf.Session()\n\n if output_graph:\n # $ tensorboard --logdir=logs\n tf.summary.FileWriter(\"logs/\", self.sess.graph)\n\n self.sess.run(tf.global_variables_initializer())\n self.cost_his = []\n\n def _build_net(self):\n # ------------------ all inputs ------------------------\n self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s') # input State\n self.s_ = tf.placeholder(tf.float32, [None, self.n_features], name='s_') # input Next State\n self.r = tf.placeholder(tf.float32, [None, ], name='r') # input Reward\n self.a = tf.placeholder(tf.int32, [None, ], name='a') # input Action\n\n w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)\n\n # ------------------ build evaluate_net ------------------\n with tf.variable_scope('eval_net'):\n e1 = tf.layers.dense(self.s, 20, tf.nn.relu, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='e1')\n self.q_eval = tf.layers.dense(e1, self.n_actions, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='q')\n\n # ------------------ build target_net ------------------\n with tf.variable_scope('target_net'):\n t1 = tf.layers.dense(self.s_, 20, tf.nn.relu, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='t1')\n self.q_next = tf.layers.dense(t1, self.n_actions, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='t2')\n\n with tf.variable_scope('q_target'):\n q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_') # shape=(None, )\n self.q_target = tf.stop_gradient(q_target)\n with tf.variable_scope('q_eval'):\n a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)\n self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices) # shape=(None, )\n with tf.variable_scope('loss'):\n self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))\n with tf.variable_scope('train'):\n self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)\n\n def store_transition(self, s, a, r, s_):\n if not hasattr(self, 'memory_counter'):\n self.memory_counter = 0\n transition = np.hstack((s, [a, r], s_))\n # replace the old memory with new memory\n index = self.memory_counter % self.memory_size\n self.memory[index, :] = transition\n self.memory_counter += 1\n\n def choose_action(self, observation):\n # to have batch dimension when feed into tf placeholder\n observation = observation[np.newaxis, :]\n\n if np.random.uniform() < self.epsilon:\n # forward feed the observation and get q value for every actions\n actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})\n action = np.argmax(actions_value)\n else:\n action = np.random.randint(0, self.n_actions)\n return action\n\n def learn(self):\n # check to replace target parameters\n if self.learn_step_counter % self.replace_target_iter == 0:\n self.sess.run(self.target_replace_op)\n print('\\ntarget_params_replaced\\n')\n\n # sample batch memory from all memory\n if self.memory_counter > self.memory_size:\n sample_index = np.random.choice(self.memory_size, size=self.batch_size)\n else:\n sample_index = np.random.choice(self.memory_counter, size=self.batch_size)\n batch_memory = self.memory[sample_index, :]\n\n _, cost = self.sess.run(\n [self._train_op, self.loss],\n feed_dict={\n self.s: batch_memory[:, :self.n_features],\n self.a: batch_memory[:, self.n_features],\n self.r: batch_memory[:, self.n_features + 1],\n self.s_: batch_memory[:, -self.n_features:],\n })\n\n self.cost_his.append(cost)\n\n # increasing epsilon\n self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max\n self.learn_step_counter += 1\n\n def plot_cost(self):\n import matplotlib.pyplot as plt\n plt.plot(np.arange(len(self.cost_his)), self.cost_his)\n plt.ylabel('Cost')\n plt.xlabel('training steps')\n plt.show()\n","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":2}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}}]}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章