強化學習—DQN:不講前世,就論今生

{"type":"doc","content":[{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"前言","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 相信小可愛們點進這篇文章,要麼是對強化學習(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Reinforcement learning-RL","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")有一定的瞭解,要麼是想要了解強化學習的魅力所在,要麼是瞭解了很多基礎知識,但是不知道代碼如何寫。今天我就以最經典和基礎的算法(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"DQN","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")帶大家一探強化學習的強大基因,不講前世(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"不講解公式推導","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"),只討論今生(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"通俗語言+圖示講解DQN算法流程,以及代碼如何實現","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"預備知識","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 如果小可愛們只想瞭解DQN算法的流程,那麼跟着我的步伐,一點一點看下去就可以。如果你想要使用代碼實現算法並親眼看到它的神奇之處,也可以在本文中找到答案。","attrs":{}}]},{"type":"blockquote","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"本文代碼實現基於tensorflow1.8,想看懂代碼的,建議熟悉tf的基礎知識","attrs":{}}]}],"attrs":{}},{"type":"horizontalrule","attrs":{}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 下面進入正題:首先想想爲什麼叫它強化學習呢?強化,強化,再強化??? 哦!想起來了,好像和我們日常學習中的強化練習、《英語強化習題集》有點相似,考過研的小可愛們都知道張宇老師還有強化課程和相關習題冊。哈哈,是不是想起來在高三或考研期間被各種強化學習資料支配的恐懼呢?不用擔心,在本文中沒有強化資料。我們再思考一下,小可愛們當初爲了考取理想的分數,不斷做題強化,是不是就是爲了提高做題速度、邏輯思考能力和正確率,從而提高成績?通過不斷做題強化,我們學到了很多知識,提高了分數,那這些指標是不是我們的","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"收益","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"?在進行強化訓練整個過程中,我們是不是一次次地和試題進行較量,拼個你死我活,接着試題答案要麼給我們當頭一擊喫個0蛋、要麼賞給我們顆糖喫喫,頓時心裏美滋滋。試題答案給了我們反饋,同時我們在接收反饋之後,反思自己,找到自己的不足並糾正,以使在以後面對這些題時,可以主動避開錯誤的解法,儘可能拿更高的分數。慢慢地,我們找到了做題套路,解題能力得到了強化提高。","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","text":"強化學習基本思想理解 ","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 強化學習的基本思想也是一樣的道理。下面是強化學習框架:","attrs":{}}]},{"type":"image","attrs":{"src":"https://static001.geekbang.org/infoq/a5/a5b03c73740d3e63e56822f147effb32.jpeg","alt":null,"title":"強化學習框架圖","style":[{"key":"width","value":"100%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"現在把圖中的","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Agent(智能體)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"當做小可愛自己,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Environment(環境)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"當做大量的試題集(包含大量的不同的試題),","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"State s(狀態)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"當做當前時刻小可愛正在做的某一試題,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Action(動作)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"當做小可愛解題過程和步驟,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"Reward(獎勵)","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"當做試題答案給小可愛的作答打分,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"State s'","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"當做在試題集中小可愛當前所做題的下一個試題【這道題小可愛做完了,也看到得了多少分,就會去做下一道試題,是不是就意味着狀態從當前狀態轉移到了下一個狀態】。小可愛們按照這種思路再次讀一讀上面一段話,對照着強化學習框架,是不是就明白了智能體與環境交互的過程,也明白了強化學習的基本思想了呢?","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","text":"強化學習中的探索和利用理解","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 如果小可愛們在讀了上面的講解之後,已經對強化學習基本思想有了進一步的理解,那麼恭喜小可愛順利通過第一關。如果還沒有更加明白,不要慌!跟着我,在學習這一部分探索和利用問題的同時,你也可以進一步理解和鞏固強化學習基本思想。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 在強化學習算法中,當智能體與環境進行交互時有一個探索的概念和一個利用的概念。顧名思義,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的意思就是智能體以前在某些狀態沒有做過這種動作【狀態和動作是一一對應的,在某一狀態下,做出某一動作】,它要嘗試解鎖一些新動作,說不定這個被解鎖的動作就可以幫助智能體得到更高的分。","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"利用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的意思就是智能體從以前做過的所有動作中,選擇一個可以幫助他獲得更高分的動作。也就是利用以前的經歷去選擇出最優動作。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 下面以小可愛做","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"一道數學題","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"爲例來說明這個問題,順便加深對強化學習基本思想的理解。這是一道非常難的一道數學題,滿分100分,有","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"上百個步驟","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",同時也有","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"很多種解法","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",這對於第一次接觸這種題的小可愛來說,太難了。但是這並難不倒聰明的小可愛們,花了一天時間,大概寫了100個步驟,終於把這道題做出來了。在對比了答案之後,事情不盡人意,只做對了幾步,得了一點點步驟分。一看只得了這麼一點分,小可愛們歸納整理做這道題的經驗並反思。又開始重新做這道題,並試圖在某些步驟嘗試不同的解法【這是不是","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"呢?】,最終花費的時間比第一次少了許多,並且得分也比以前高了。這讓小可愛們甚是興奮,更加自信了。小可愛們試圖又重新做這道題,基於前幾次的做題經驗【這裏是不是包含了","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"利用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"呢?】以及","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"使用不同的方法,發現有時候在某些步驟上用比較複雜的方法,但是卻爲後面的解題大大減少了計算量,從而減小了錯誤概率。最後,儘管用時稍微長些,但是得分更高了。就這樣一次又一次地重新做這道題,小可愛們慢慢地摸索到了這道題的最佳做題套路,在","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"用時和得分之間尋找到了平衡","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",使得最終所摸索到的關於這道題的做題技巧","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"性價比","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"最高。","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 看完上面一段話,小可愛們是不是對於強化學習中的探索和利用以及基本思想有了更加深刻地理解了呢?恭喜你,小可愛,又完美通過一關。","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"圖解DQN算法","attrs":{}}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 這個部分,將使用圖解和文字結合的方法對DQN算法進行剖析。先上個圖:","attrs":{}}]},{"type":"image","attrs":{"src":"https://static001.geekbang.org/infoq/c1/c18cbca97b5e299a7892a1d66cb46262.jpeg","alt":null,"title":"DQN算法結構圖","style":[{"key":"width","value":"100%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"如果看不懂上面的圖,沒關係。下面該部分將再次用圖解的方式對DQN算法流程進行更細緻的剖析,上圖整活:","attrs":{}}]},{"type":"image","attrs":{"src":"https://static001.geekbang.org/infoq/75/7584eb830d8c94d6712f88a32823dedb.jpeg","alt":null,"title":"DQN算法深入剖析圖","style":[{"key":"width","value":"100%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"高能預警!","attrs":{}}]},{"type":"blockquote","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"注意上圖中的eval網絡和target網絡結構是完全一樣的,但是權重參數不是實時同步的,target網絡權重參數更新是落後於eval網絡的權重參數更新的。","attrs":{}}]}],"attrs":{}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":" 從以上兩個圖中我們可以知道,DQN算法是分爲兩個部分的。","attrs":{}}]},{"type":"bulletedlist","content":[{"type":"listitem","attrs":{"listStyle":null},"content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"第一部分—經驗收集階段:智能體想要從幼兒園進化到大學水平,學習成績更好那就必須要有經驗啊。那這一部分就是智能體與環境進行交互收集經驗的階段。注意:這一部分只會用到eval網絡,智能體從環境中觀察到狀態","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"輸入到eval網絡,網絡會輸出每一個動作對應的Q值。智能體在選擇動作時一般使用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"ε貪心策略","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"來選擇動作—【在整個訓練初始階段,大概率隨機選擇動作(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"探索","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"),隨着訓練輪數的增加,小概率隨機選擇動作,大概率選擇最大Q值對應的動作(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"利用","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")】。在選擇動作後,會作用到環境,環境會相應的反饋給智能體一個獎勵","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"r","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":",同時環境從當前狀態","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"轉移到","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s'","attrs":{}},{"type":"text","text":"。","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"每進行一次交互,都會把(","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s,a,r,s'","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":")保存到經驗池中。","attrs":{}}]}]},{"type":"listitem","attrs":{"listStyle":null},"content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"第二部分—訓練學習階段:在交互次數達到預先設定的值後,就進入到了","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"訓練學習","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"階段。這個階段是智能體真正利用以前的經驗來進行反思提高的(訓練)。智能體會從經驗池中隨機抽取","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"batch-size組","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的數據,把這組數據中的","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"批量輸入到eval網絡,","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"s'","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"批量輸入到target網絡中,得到","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"該組數據中每一條數據","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"對應的q_eval和q_target。然後計算","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}},{"type":"strong","attrs":{}}],"text":"該組數據","attrs":{}},{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"的損失,然後進行反向傳播進行訓練。訓練結束後,再次進入到第一部分。","attrs":{}}]}]}],"attrs":{}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":16}}],"text":"以上兩個部分交替進行,直到達到預先設定的訓練輪數。相信各位小可愛在看了本部分的介紹,明白了DQN算法的具體流程,恭喜你,又一次成功通關。","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":1},"content":[{"type":"text","text":"代碼實現—基於tensorflow1版本","attrs":{}}]},{"type":"codeblock","attrs":{"lang":"python"},"content":[{"type":"text","text":"\nimport numpy as np\nimport tensorflow as tf\n\nnp.random.seed(1)\ntf.set_random_seed(1)\n\n\n# Deep Q Network off-policy\nclass DeepQNetwork:\n def __init__(\n self,\n n_actions,\n n_features,\n learning_rate=0.01,\n reward_decay=0.9,\n e_greedy=0.9,\n replace_target_iter=300,\n memory_size=500,\n batch_size=32,\n e_greedy_increment=None,\n output_graph=False,\n ):\n self.n_actions = n_actions\n self.n_features = n_features\n self.lr = learning_rate\n self.gamma = reward_decay\n self.epsilon_max = e_greedy\n self.replace_target_iter = replace_target_iter\n self.memory_size = memory_size\n self.batch_size = batch_size\n self.epsilon_increment = e_greedy_increment\n self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max\n\n # total learning step\n self.learn_step_counter = 0\n\n # initialize zero memory [s, a, r, s_]\n self.memory = np.zeros((self.memory_size, n_features * 2 + 2))\n\n # consist of [target_net, evaluate_net]\n self._build_net()\n\n t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')\n e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')\n\n with tf.variable_scope('hard_replacement'):\n self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]\n\n self.sess = tf.Session()\n\n if output_graph:\n # $ tensorboard --logdir=logs\n tf.summary.FileWriter(\"logs/\", self.sess.graph)\n\n self.sess.run(tf.global_variables_initializer())\n self.cost_his = []\n\n def _build_net(self):\n # ------------------ all inputs ------------------------\n self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s') # input State\n self.s_ = tf.placeholder(tf.float32, [None, self.n_features], name='s_') # input Next State\n self.r = tf.placeholder(tf.float32, [None, ], name='r') # input Reward\n self.a = tf.placeholder(tf.int32, [None, ], name='a') # input Action\n\n w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)\n\n # ------------------ build evaluate_net ------------------\n with tf.variable_scope('eval_net'):\n e1 = tf.layers.dense(self.s, 20, tf.nn.relu, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='e1')\n self.q_eval = tf.layers.dense(e1, self.n_actions, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='q')\n\n # ------------------ build target_net ------------------\n with tf.variable_scope('target_net'):\n t1 = tf.layers.dense(self.s_, 20, tf.nn.relu, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='t1')\n self.q_next = tf.layers.dense(t1, self.n_actions, kernel_initializer=w_initializer,\n bias_initializer=b_initializer, name='t2')\n\n with tf.variable_scope('q_target'):\n q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_') # shape=(None, )\n self.q_target = tf.stop_gradient(q_target)\n with tf.variable_scope('q_eval'):\n a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)\n self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices) # shape=(None, )\n with tf.variable_scope('loss'):\n self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))\n with tf.variable_scope('train'):\n self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)\n\n def store_transition(self, s, a, r, s_):\n if not hasattr(self, 'memory_counter'):\n self.memory_counter = 0\n transition = np.hstack((s, [a, r], s_))\n # replace the old memory with new memory\n index = self.memory_counter % self.memory_size\n self.memory[index, :] = transition\n self.memory_counter += 1\n\n def choose_action(self, observation):\n # to have batch dimension when feed into tf placeholder\n observation = observation[np.newaxis, :]\n\n if np.random.uniform() < self.epsilon:\n # forward feed the observation and get q value for every actions\n actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})\n action = np.argmax(actions_value)\n else:\n action = np.random.randint(0, self.n_actions)\n return action\n\n def learn(self):\n # check to replace target parameters\n if self.learn_step_counter % self.replace_target_iter == 0:\n self.sess.run(self.target_replace_op)\n print('\\ntarget_params_replaced\\n')\n\n # sample batch memory from all memory\n if self.memory_counter > self.memory_size:\n sample_index = np.random.choice(self.memory_size, size=self.batch_size)\n else:\n sample_index = np.random.choice(self.memory_counter, size=self.batch_size)\n batch_memory = self.memory[sample_index, :]\n\n _, cost = self.sess.run(\n [self._train_op, self.loss],\n feed_dict={\n self.s: batch_memory[:, :self.n_features],\n self.a: batch_memory[:, self.n_features],\n self.r: batch_memory[:, self.n_features + 1],\n self.s_: batch_memory[:, -self.n_features:],\n })\n\n self.cost_his.append(cost)\n\n # increasing epsilon\n self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max\n self.learn_step_counter += 1\n\n def plot_cost(self):\n import matplotlib.pyplot as plt\n plt.plot(np.arange(len(self.cost_his)), self.cost_his)\n plt.ylabel('Cost')\n plt.xlabel('training steps')\n plt.show()\n","attrs":{}}]},{"type":"heading","attrs":{"align":null,"level":2}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}}]}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章