Deep Learning: Q-learning Algorithm Implementation
1. Problem Analysis
This is a cliff-walking problem. The reinforcement-learning agent starts at S, and an episode ends when it reaches G. Except at the edges, four actions (up, down, left, right) are available in every cell. Stepping into the cliff region gives a reward of -100, stepping into any of the three circles in the middle gives a reward of -1, and stepping into any other cell gives a reward of -5.
This is the classic cliff-walking Q-learning problem: find the path that maximizes the total reward. The map can be converted into the following reward matrix:
```
[[  -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.]
 [  -5.   -5.   -5.   -5.   -5.   -1.   -1.   -1.   -5.   -5.   -5.   -5.]
 [  -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.]
 [  -5. -100. -100. -100. -100. -100. -100. -100. -100. -100. -100.  100.]]
```
Our goal is to lead the agent from S at (3, 0) to G at (3, 11) along the path that maximizes the total reward, i.e. to learn the optimal policy.
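For reference, this reward matrix can be built with a few lines of NumPy; the sketch below mirrors what `_reward_init` does in the appendix program:

```python
import numpy as np

reward = np.full((4, 12), -5.0)   # every ordinary cell costs -5
reward[1, 5:8] = -1.0             # the three bonus circles
reward[3, 1:11] = -100.0          # the cliff region
reward[3, 11] = 100.0             # the goal G (modified value, see Section 3)
```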
2. Q-learning Theory
Two terms are particularly important in the Q-learning algorithm: the state and the action. In this problem the state is the agent's position on the cliff map, and the action is the agent's next move. I encode the actions as (0, 1, 2, 3, 4), standing for (stay, up, down, left, right). Note that the next action is chosen at random, but which actions are available depends on the current state.
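For illustration, this encoding can be written as a lookup table mapping each action to a (row, col) delta; the `ACTION_DELTAS` table and `step` helper below are hypothetical names for this sketch, not part of the appendix program:

```python
# Hypothetical illustration of the action encoding used in this report:
# 0 = stay, 1 = up, 2 = down, 3 = left, 4 = right.
ACTION_DELTAS = {
    0: (0, 0),   # stay
    1: (-1, 0),  # up (one row towards the top of the map)
    2: (1, 0),   # down
    3: (0, -1),  # left
    4: (0, 1),   # right
}

def step(state, action):
    """Apply an action's (row, col) delta to a state tuple."""
    dr, dc = ACTION_DELTAS[action]
    return (state[0] + dr, state[1] + dc)
```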
The core of the method is the Q-learning transition rule. The agent learns by applying this rule repeatedly, storing its experience in the Q-table and accumulating it over many iterations until the goal is reached. One such trial-and-error run that ends at the goal is called an episode. The rule is:
$$Q(s, a) = R(s, a) + \gamma \cdot \max_{\tilde{a}} Q(\tilde{s}, \tilde{a})$$

where $s$ and $a$ are the current state and action, $\tilde{s}$ and $\tilde{a}$ are the next state and the actions available there, and $\gamma$ is the learning parameter; the closer $\gamma$ is to 1, the more weight the long-term return receives.
Since the agent knows nothing about its environment at the start, the Q-table is initialized as a zero matrix.
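As a concrete example of the rule (with the $\gamma = 0.7$ used in the program): the very first update from $S = (3, 0)$ with action 1 (up) lands in cell $(2, 0)$, whose reward is $-5$; since every Q-table entry is still zero,

$$Q((3,0),\, 1) = R((3,0),\, 1) + 0.7 \cdot \max_{\tilde{a}} Q((2,0),\, \tilde{a}) = -5 + 0.7 \cdot 0 = -5.$$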
3. Algorithm Implementation
Refer to the following pseudocode:
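A sketch of the standard tabular Q-learning loop, which is what the appendix program implements:

```
initialize Q(s, a) = 0 for every state s and action a
repeat for each episode:
    s <- start state S
    while s is not the goal G:
        pick a random valid action a from s
        take a, observe reward R(s, a) and next state s~
        Q(s, a) <- R(s, a) + gamma * max over a~ of Q(s~, a~)
        s <- s~
```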
The full program can be found in the appendix.
Key points of the program:
- The core code follows the pseudocode directly, but the various helper methods have to be implemented by hand; the comments in the program can be consulted.
- The actions available to the agent in a given state must be determined; I implement this with `valid_action(self, current_state)`.
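For example, at the start state (3, 0), the bottom-left corner, only stay, up, and right are possible, so `valid_action` returns `[0, 1, 4]`; at an interior cell such as (1, 5) it returns all five actions `[0, 1, 2, 3, 4]`; at the top-right corner (0, 11) it returns `[0, 2, 3]`.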
**Problem found:** in the problem statement, the reward at the goal point G is also -1. The program does step onto G, but the value function does not converge there; moreover, since the payoff at the bonus circles is larger, the agent ends up converging to the circles and moving back and forth among the three of them. I therefore changed the value of the goal point G to 100, after which the function converges to the goal. Later I also came across the absorbing goal described in the literature.
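For comparison, a minimal sketch of the absorbing-goal idea (the `q_update` helper is a hypothetical name, and this is not the approach taken in the appendix program): the goal is treated as terminal, so the update backs up only the immediate reward and no artificially large goal value is needed.

```python
def q_update(q, reward, state, action, next_state, next_actions, gamma, terminal):
    """One Q-learning backup with an absorbing (terminal) goal state."""
    row, col = state
    if next_state == terminal:
        # Absorbing goal: no discounted future term beyond this step.
        target = reward
    else:
        nrow, ncol = next_state
        target = reward + gamma * max(q[nrow][ncol][a] for a in next_actions)
    q[row][col][action] = target
```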
4. Results
The final Q-table matrix is too large to show here, so it is placed in the appendix. To make the result easier to see, I also wrote a dynamic drawing program that traces all of the paths. To watch the animation, run the program; the final result is shown in the figure below:
As the figure shows, the agent avoids the cliff everywhere, picks up all of the bonus circles, and finally reaches the goal.
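For readers without a display, here is a turtle-free sketch that extracts one greedy path from the learned Q-table (the `greedy_path` name and the `max_steps` guard are additions for this sketch, not part of the appendix program):

```python
import random

def greedy_path(q_matrix, valid_action, transition, start, goal, max_steps=200):
    """Follow the highest-valued action at each state, breaking ties at random."""
    state, path = start, [start]
    for _ in range(max_steps):          # guard against a looping policy
        if state == goal:
            break
        row, col = state
        actions = valid_action(state)
        best_q = max(q_matrix[row][col][a] for a in actions)
        best = [a for a in actions if q_matrix[row][col][a] == best_q]
        state = transition(state, random.choice(best))
        path.append(state)
    return path

# e.g. greedy_path(cliff.q_matrix, cliff.valid_action, cliff.transition,
#                  (3, 0), (3, 11))
```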
5. Appendix
Program:
```python
# -*- coding: utf-8 -*-
# qvkang
import numpy as np
import random
import turtle as t

class Cliff(object):
    def __init__(self):
        self.reward = self._reward_init()
        print(self.reward)
        self.row = 4
        self.col = 12
        self.gamma = 0.7
        self.start_state = (3, 0)
        self.end_state = (3, 11)
        self.q_matrix = np.zeros((4, 12, 5))
        self.main()

    def _reward_init(self):
        re = np.ones((4, 12)) * -5
        # bonus circles
        re[1][5:8] = np.ones(3) * -1
        # cliff
        re[3][1:11] = np.ones(10) * -100
        # goal
        re[3][11] = 100
        return re

    def valid_action(self, current_state):
        # Determine which actions can be taken from the current state.
        itemrow, itemcol = current_state
        valid = [0]
        if itemrow - 1 >= 0: valid.append(1)
        if itemrow + 1 <= self.row - 1: valid.append(2)
        if itemcol - 1 >= 0: valid.append(3)
        if itemcol + 1 <= self.col - 1: valid.append(4)
        return valid

    def transition(self, current_state, action):
        # Move from the current state to the next state.
        itemrow, itemcol = current_state
        if action == 0: next_state = current_state
        if action == 1: next_state = (itemrow - 1, itemcol)
        if action == 2: next_state = (itemrow + 1, itemcol)
        if action == 3: next_state = (itemrow, itemcol - 1)
        if action == 4: next_state = (itemrow, itemcol + 1)
        return next_state

    def _indextoPosition(self, index):
        # Convert a flat cell index to a (row, col) position.
        itemrow = index // self.col
        itemcol = index % self.col
        return (itemrow, itemcol)

    def _positiontoIndex(self, itemrow, itemcol):
        # Convert a (row, col) position to a flat cell index.
        return itemrow * self.col + itemcol

    def getreward(self, current_state, action):
        # Reward received for taking one step from the current state.
        next_state = self.transition(current_state, action)
        next_row, next_col = next_state
        return self.reward[next_row, next_col]

    def path(self):
        # Draw one greedy path with the turtle graphics library.
        t.speed(10)
        t.pensize(5)
        t.penup()
        current_row, current_col = self.start_state
        t.goto(current_col * 20, 60 - current_row * 20)  # move to the start
        t.pendown()
        current_state = self.start_state
        paths = [current_state]
        while current_state != self.end_state:
            current_row, current_col = current_state
            valid_actions = self.valid_action(current_state)
            q_row = self.q_matrix[current_row][current_col]
            max_value = max(q_row[a] for a in valid_actions)
            # Break ties at random, but only among the valid actions.
            best = [a for a in valid_actions if q_row[a] == max_value]
            action = random.choice(best)
            print(current_state, '-------------', action)
            next_state = self.transition(current_state, action)
            paths.append(next_state)
            next_row, next_col = next_state
            t.goto(next_col * 20, 60 - next_row * 20)
            current_state = next_state

    def main(self):
        # Main training loop.
        for i in range(1000):
            current_state = self.start_state
            while current_state != self.end_state:
                action = random.choice(self.valid_action(current_state))
                next_state = self.transition(current_state, action)
                next_row, next_col = next_state
                future_rewards = [self.q_matrix[next_row][next_col][a]
                                  for a in self.valid_action(next_state)]
                # Core transition rule: Q(s,a) = R(s,a) + gamma * max Q(s~,a~)
                q_state = (self.getreward(current_state, action)
                           + self.gamma * max(future_rewards))
                current_row, current_col = current_state
                self.q_matrix[current_row][current_col][action] = q_state
                current_state = next_state
        # Draw the greedy path 1000 times; random tie-breaking traces
        # every optimal route.
        for i in range(1000):
            self.path()
        print(self.q_matrix)

if __name__ == "__main__":
    Cliff()
```
Final Q-table matrix:

```
[[[ -14.84480118 0. -14.06400168 0. -14.06400168]
[ -14.06400168 0. -12.94857383 -14.84480118 -12.94857383]
[ -12.94857383 0. -11.35510547 -14.06400168 -11.35510547]
[ -11.35510547 0. -9.07872209 -12.94857383 -9.07872209]
[ -9.07872209 0. -5.82674585 -11.35510547 -5.82674585]
[ -5.82674585 0. -1.1810655 -9.07872209 -5.1810655 ]
[ -5.1810655 0. -0.258665 -5.82674585 -4.258665 ]
[ -4.258665 0. 1.05905 -5.1810655 -2.94095 ]
[ -2.94095 0. 2.9415 -4.258665 2.9415 ]
[ 2.9415 0. 11.345 -2.94095 11.345 ]
[ 11.345 0. 23.35 2.9415 23.35 ]
[ 23.35 0. 40.5 11.345 0. ]]
[[ -14.06400168 -14.84480118 -14.84480118 0. -12.94857383]
[ -12.94857383 -14.06400168 -14.06400168 -14.06400168 -11.35510547]
[ -11.35510547 -12.94857383 -12.94857383 -12.94857383 -9.07872209]
[ -9.07872209 -11.35510547 -11.35510547 -11.35510547 -5.82674585]
[ -5.82674585 -9.07872209 -9.07872209 -9.07872209 -1.1810655 ]
[ -1.1810655 -5.82674585 -5.82674585 -5.82674585 -0.258665 ]
[ -0.258665 -5.1810655 -2.94095 -1.1810655 1.05905 ]
[ 1.05905 -4.258665 2.9415 -0.258665 2.9415 ]
[ 2.9415 -2.94095 11.345 1.05905 11.345 ]
[ 11.345 2.9415 23.35 2.9415 23.35 ]
[ 23.35 11.345 40.5 11.345 40.5 ]
[ 40.5 23.35 65. 23.35 0. ]]
[[ -14.84480118 -14.06400168 -15.39136082 0. -14.06400168]
[ -14.06400168 -12.94857383 -109.84480118 -14.84480118 -12.94857383]
[ -12.94857383 -11.35510547 -109.06400168 -14.06400168 -11.35510547]
[ -11.35510547 -9.07872209 -107.94857383 -12.94857383 -9.07872209]
[ -9.07872209 -5.82674585 -106.35510547 -11.35510547 -5.82674585]
[ -5.82674585 -1.1810655 -104.0787221 -9.07872209 -2.94095 ]
[ -2.94095 -0.258665 -102.058665 -5.82674585 2.9415 ]
[ 2.9415 1.05905 -97.94095 -2.94095 11.345 ]
[ 11.345 2.9415 -92.0585 2.9415 23.35 ]
[ 23.35 11.345 -83.655 11.345 40.5 ]
[ 40.5 23.35 -30. 23.35 65. ]
[ 65. 40.5 100. 40.5 0. ]]
[[ -15.39136082 -14.84480118 0. 0. -109.84480118]
[-109.84480118 -14.06400168 0. -15.39136082 -109.06400168]
[-109.06400168 -12.94857383 0. -109.84480118 -107.94857383]
[-107.94857383 -11.35510547 0. -109.06400168 -106.35510547]
[-106.35510547 -9.07872209 0. -107.94857383 -104.0787221 ]
[-104.0787221 -5.82674585 0. -106.35510547 -102.058665 ]
[-102.058665 -2.94095 0. -104.0787221 -97.94095 ]
[ -97.94095 2.9415 0. -102.058665 -92.0585 ]
[ -92.0585 11.345 0. -97.94095 -83.655 ]
[ -83.655 23.35 0. -92.0585 -30. ]
[ -30. 40.5 0. -83.655 100. ]
[ 0. 0. 0. 0. 0. ]]]
```