本文內容源自百度強化學習 7 日入門課程學習整理
感謝百度 PARL 團隊李科澆老師的課程講解
文章目錄
一、爲什麼要引入神經網絡
Q 表只能解決少量狀態的問題,如果狀態數量上漲,那我們面對的可能性呈現指數上漲,這樣的話Q表格就沒有這個處理能力了
比如:
- 國際象棋:種狀態
- 圍棋:種狀態
- 連續操作的問題:不可數狀態(不如彎曲角度)
- (整個宇宙的原子數量預估:)
Q表格不行的時候,我們可以採用:值函數(Q函數)近似
Q表格的作用在於:輸入狀態和動作,輸出Q值
那我們可以用一個 “帶參數” 的 Q 函數來進行替代:
- 多項式函數
- 神經網絡
不同的近似方式:
- 輸入狀態 s 和動作 a,輸出一個q 值
- 輸入狀態 s,輸出多個 q 值(不同動作所對應的 q 值)
Q表格方法的缺點:
- 表格佔用大量內存
- 表格大的時候,查表效率低
值函數近似的優點:
- 僅需存儲有限數量的參數
- 狀態泛化,相似的狀態可以輸出一樣
神經網絡可以逼近任意連續的函數
- 比如 CNN 在強化學習中引入後,可以讓強化學習算法根據圖片做出決策(輸入圖片,輸出動作)
- 神經網絡的原理在於,定義 cost 爲真實值和預測值之間的差距,然後用梯度下降來最小化 cost
二、DQN 算法
DQN 是使用神經網絡解決強化學習問題最經典的算法
該算法由谷歌的 DeepMind 團隊在 2015 年提出
《Human-level control through deep reinforcement learning》這篇論文被髮表在了 Nature 雜誌上
通過高維度的輸入信息(像素級別的圖像),使用了神經網絡的 DQN 在 49 個 Atari 遊戲中,有 30 個超越了人類水平
使用神經網絡代替Q表格以後:
- 輸入可以是一個向量,包含各種值(比如四軸飛行器的高度,角度,轉速等)
- 輸入可以是一個圖片,包含各個像素點的信息
- 輸出直接是對應的動作
2.1 DQN 約等於 Q-learning + 神經網絡
- 輸入 狀態 s
- 輸出 q 向量,如果一個狀態下有 5 種動作,那 q 就是 5 維的
- 然後根據我們具體的動作選擇,確定 q 值
- 然後要讓輸出的 q 值,逼近 目標 q 值(target_q)
- target_q 的計算公式就是 Q-learning 的方法:
- 神經網絡輸出的預測值:
- 計算預測值和目標值的均方差(即 loss):
- 使用優化器,最小化 loss
2.2 DQN 的兩大創新
神經網絡中由於引入了非線形函數,比如 “relu”
所以在理論上,無法證明訓練之後一定會收斂
於是 DQN 提出兩大創新,使得訓練更有效率,也更穩定
2.2.1 經驗回放 Experience replay
作用:
- 解決序列決策的樣本關聯性問題
- 解決樣本利用率低的問題
問題來源:
- 在監督學習中,訓練樣本是獨立的
- 但是在強化學習中,輸入的是狀態值,每一個狀態都是連續發生,前後狀態相互關聯,所以樣本之間具有關聯性
解決方案:
- 需要打亂,或者切斷輸入樣本之間的聯繫
- 這裏用到了 Q-learning 的 Off-Policy 特點
- 先存儲一批經驗數據
- 然後打亂
- 從中隨機選取一個小的 batch 來更新網絡
- 這樣就打破了樣本間的相關性,同時使得網絡更有效率
Off-Policy 在經驗回放中的作用:
- 設置經驗池:是一個固定長度的隊列
- 一條經驗指的是:一組 ,,,
- 每拿到一條經驗就往經驗池進行存儲
- 滿了以後,彈出舊的經驗
- 從經驗池中隨機抽取一個 batch
- 去更新 Q 值(這裏就是更新神經網絡的係數)
優點:
- 由於經驗池中的數據有可能被重複抽取到,所以相當於經驗可以重複利用,即提高了樣本的利用率
- 另外由於是隨機抽取,所以打亂了樣本間的相關性
2.2.2 固定 Q 目標 Fixed Q target
作用:
- 解決算法訓練不穩定的問題
問題來源:
- 監督學習中,我們預測值要去逼近真實值,而真實值是固定不變的
- 但是在 DQN 中,輸入狀態輸出預測的Q,要逼近的是目標Q
- 其中 也是神經網絡的輸出,而神經網絡權重係數一旦更新以後,這個值也會發生變化
- 所以只要我們更新一次神經網絡,那目標 Q 值也就會不斷變化
解決方法:
- 我們要想辦法把 Q-target 值固定住
- 也就是我們要把輸出 Q-target 的神經網絡參數固定一段時間
- 然後過一段時間以後,再用最新的學習後的神經網絡參數,刷新這個神經網絡
2.3 DQN 流程框架圖
Model:
- 代替了 Q 表
- 輸入 S 輸出 不同動作對應的 Q(預測值)給 Agent
- 同時設定一個固定一段時間的神經網絡用於輸出 Q_target
- 過一段時間更新該固定網絡參數
引入神經網絡的問題解決:
- 經驗回放
- 固定目標值
Agent:
- 和環境交互
- 交互數據(經驗)存儲到經驗池
- 提取經驗池數據,更新 Model 參數(利用最小化 預測值和目標值之間的 loss)——DQN最核心部分
2.4 PARL 的 DQN 框架
分爲 model,algorithm,agent 這 3 個部分
- model:用來定義神經網絡部分的網絡結構,同時實現模型複製
- algorithm:實現具體算法,如何定義損失函數,更新 model,主要包含了 predict() 和 learn() 兩個函數
- agent:負責和環境做交互,數據預處理,構建計算圖
總體抽象來說:
- Agent 包含了 Algorithm 和 Model
- Algorithm 包含了 Model
PARL 常用的 API:
- agent.save():保存模型
- agent.restore():加載模型
- model.sync_weights_to():把當前模型的參數同步到另一個模型去
- model.parameters():返回一個 list,包含模型所有參數的名稱
- model.get_weights():返回一個 list,包含模型的所有參數
- model.set_weights():設置模型參數
PARL 裏面打印日誌的工具:
- parl.utils.logger:打印日誌,涵蓋時間,代碼所在文件及行數,方便記錄訓練時間
PARL 的 API 文檔地址:
https://parl.readthedocs.io/en/latest/model.html