強化學習入門(三)將神經網絡引入強化學習,經典算法 DQN

本文內容源自百度強化學習 7 日入門課程學習整理
感謝百度 PARL 團隊李科澆老師的課程講解

一、爲什麼要引入神經網絡

Q 表只能解決少量狀態的問題,如果狀態數量上漲,那我們面對的可能性呈現指數上漲,這樣的話Q表格就沒有這個處理能力了

比如:

  • 國際象棋:104710^{47}種狀態
  • 圍棋:1017010^{170}種狀態
  • 連續操作的問題:不可數狀態(不如彎曲角度)
  • (整個宇宙的原子數量預估:108010^{80}

Q表格不行的時候,我們可以採用:值函數(Q函數)近似

Q表格的作用在於:輸入狀態和動作,輸出Q值

那我們可以用一個 “帶參數” 的 Q 函數來進行替代:qπ(s,a)  q^(s,a,w)q^π(s,a)\ \approx\ \hat{q}(s,a,\textbf{w})

  • 多項式函數
  • 神經網絡

不同的近似方式:

  • 輸入狀態 s 和動作 a,輸出一個q 值 q^(s,a,w)\hat{q}(s,a,\textbf{w})
  • 輸入狀態 s,輸出多個 q 值(不同動作所對應的 q 值) q^(s,a1,w) ... q^(s,am,w)\hat{q}(s,a_1,\textbf{w})\ ...\ \hat{q}(s,a_m,\textbf{w})

Q表格方法的缺點:

  • 表格佔用大量內存
  • 表格大的時候,查表效率低

值函數近似的優點:

  • 僅需存儲有限數量的參數
  • 狀態泛化,相似的狀態可以輸出一樣

神經網絡可以逼近任意連續的函數

  • 比如 CNN 在強化學習中引入後,可以讓強化學習算法根據圖片做出決策(輸入圖片,輸出動作)
  • 神經網絡的原理在於,定義 cost 爲真實值和預測值之間的差距,然後用梯度下降來最小化 cost

二、DQN 算法

DQN 是使用神經網絡解決強化學習問題最經典的算法

該算法由谷歌的 DeepMind 團隊在 2015 年提出

《Human-level control through deep reinforcement learning》這篇論文被髮表在了 Nature 雜誌上

通過高維度的輸入信息(像素級別的圖像),使用了神經網絡的 DQN 在 49 個 Atari 遊戲中,有 30 個超越了人類水平

使用神經網絡代替Q表格以後:

  • 輸入可以是一個向量,包含各種值(比如四軸飛行器的高度,角度,轉速等)
  • 輸入可以是一個圖片,包含各個像素點的信息
  • 輸出直接是對應的動作

2.1 DQN 約等於 Q-learning + 神經網絡

  • 輸入 狀態 s
  • 輸出 q 向量,如果一個狀態下有 5 種動作,那 q 就是 5 維的
  • 然後根據我們具體的動作選擇,確定 q 值
  • 然後要讓輸出的 q 值,逼近 目標 q 值(target_q)
    • target_q 的計算公式就是 Q-learning 的方法:qπ(s,a) = r + γmaxaq^(s,a,w)q_π(s,a)\ =\ r\ +\ γ\max\limits_{a'}\hat{q}(s',a',\textbf{w})
    • 神經網絡輸出的預測值:q^(s,a,w)\hat{q}(s,a,\textbf{w})
    • 計算預測值和目標值的均方差(即 loss):Eπ[(qπ(s,a)  q^(s,a,w))2]E_π[(q_π(s,a)\ - \ \hat{q}(s,a,\textbf{w}))^2]
  • 使用優化器,最小化 loss
    在這裏插入圖片描述

2.2 DQN 的兩大創新

神經網絡中由於引入了非線形函數,比如 “relu”

所以在理論上,無法證明訓練之後一定會收斂

於是 DQN 提出兩大創新,使得訓練更有效率,也更穩定

2.2.1 經驗回放 Experience replay

作用:

  • 解決序列決策的樣本關聯性問題
  • 解決樣本利用率低的問題

問題來源:

  • 在監督學習中,訓練樣本是獨立的
  • 但是在強化學習中,輸入的是狀態值,每一個狀態都是連續發生,前後狀態相互關聯,所以樣本之間具有關聯性

解決方案:

  • 需要打亂,或者切斷輸入樣本之間的聯繫
  • 這裏用到了 Q-learning 的 Off-Policy 特點
  • 先存儲一批經驗數據
  • 然後打亂
  • 從中隨機選取一個小的 batch 來更新網絡
  • 這樣就打破了樣本間的相關性,同時使得網絡更有效率

Off-Policy 在經驗回放中的作用:

  • 設置經驗池:是一個固定長度的隊列
  • 一條經驗指的是:一組 sts_tata_trt+1r_{t+1}st+1s_{t+1}
  • 每拿到一條經驗就往經驗池進行存儲
  • 滿了以後,彈出舊的經驗
  • 從經驗池中隨機抽取一個 batch
  • 去更新 Q 值(這裏就是更新神經網絡的係數)

在這裏插入圖片描述

優點:

  • 由於經驗池中的數據有可能被重複抽取到,所以相當於經驗可以重複利用,即提高了樣本的利用率
  • 另外由於是隨機抽取,所以打亂了樣本間的相關性

2.2.2 固定 Q 目標 Fixed Q target

作用:

  • 解決算法訓練不穩定的問題

問題來源:

  • 監督學習中,我們預測值要去逼近真實值,而真實值是固定不變的
  • 但是在 DQN 中,輸入狀態輸出預測的Q,要逼近的是目標Q
  • Q_target = r + γ max Q(s,a,θ)Q\_target\ =\ r\ +\ γ\ max\ Q(s',a',θ)
  • 其中 max Q(s,a,θ)max\ Q(s',a',θ) 也是神經網絡的輸出,而神經網絡權重係數一旦更新以後,這個值也會發生變化
  • 所以只要我們更新一次神經網絡,那目標 Q 值也就會不斷變化

解決方法:

  • 我們要想辦法把 Q-target 值固定住
  • 也就是我們要把輸出 Q-target 的神經網絡參數固定一段時間
  • 然後過一段時間以後,再用最新的學習後的神經網絡參數,刷新這個神經網絡

2.3 DQN 流程框架圖

在這裏插入圖片描述
Model:

  • 代替了 Q 表
  • 輸入 S 輸出 不同動作對應的 Q(預測值)給 Agent
  • 同時設定一個固定一段時間的神經網絡用於輸出 Q_target
  • 過一段時間更新該固定網絡參數

引入神經網絡的問題解決:

  • 經驗回放
  • 固定目標值

Agent:

  • 和環境交互
  • 交互數據(經驗)存儲到經驗池
  • 提取經驗池數據,更新 Model 參數(利用最小化 預測值和目標值之間的 loss)——DQN最核心部分

2.4 PARL 的 DQN 框架

在這裏插入圖片描述
分爲 model,algorithm,agent 這 3 個部分

  • model:用來定義神經網絡部分的網絡結構,同時實現模型複製
  • algorithm:實現具體算法,如何定義損失函數,更新 model,主要包含了 predict() 和 learn() 兩個函數
  • agent:負責和環境做交互,數據預處理,構建計算圖

在這裏插入圖片描述
總體抽象來說:

  • Agent 包含了 Algorithm 和 Model
  • Algorithm 包含了 Model

PARL 常用的 API:

  • agent.save():保存模型
  • agent.restore():加載模型
  • model.sync_weights_to():把當前模型的參數同步到另一個模型去
  • model.parameters():返回一個 list,包含模型所有參數的名稱
  • model.get_weights():返回一個 list,包含模型的所有參數
  • model.set_weights():設置模型參數

PARL 裏面打印日誌的工具:

  • parl.utils.logger:打印日誌,涵蓋時間,代碼所在文件及行數,方便記錄訓練時間

PARL 的 API 文檔地址:

https://parl.readthedocs.io/en/latest/model.html

三、DQN 算法代碼詳解

強化學習算法 DQN 解決 CartPole 問題,代碼逐條詳解

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章