Deep Learning: Q-learning Algorithm Implementation


1. Problem Analysis

[Figure: the cliff-walking grid world]

This is a cliff-walking problem. The reinforcement learning agent starts at S, and the episode ends when it reaches G. Except at the edge of the grid, every cell offers four actions: up, down, left and right. If the agent steps into the cliff region, the reward is -100; stepping into any of the three circles in the middle gives a reward of -1; stepping into any other cell gives a reward of -5.

This is the classic cliff-walking Q-learning problem: the task is to choose the path with the maximum return. The map can be converted into the following reward matrix:

[[  -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.]
 [  -5.   -5.   -5.   -5.   -5.   -1.   -1.   -1.   -5.   -5.   -5.   -5.]
 [  -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.   -5.]
 [  -5. -100. -100. -100. -100. -100. -100. -100. -100. -100. -100.  100.]]

Our goal is to have the agent travel from S (3, 0) to G (3, 11) along the path with the maximum return, i.e. to learn the optimal policy.
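
For reference, this reward matrix can be built with a few lines of NumPy. The sketch below mirrors the _reward_init method from the appendix; note that the goal cell is already set to 100 here, a change that is explained in Section 3:

	import numpy as np
	
	# 4 x 12 grid: every ordinary cell costs -5
	reward = np.full((4, 12), -5.0)
	reward[1, 5:8] = -1.0      # the three reward circles
	reward[3, 1:11] = -100.0   # the cliff between S and G
	reward[3, 11] = 100.0      # the goal G (see Section 3)
	print(reward)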

2. Q-learning Theory

Two terms are especially important in the Q-learning algorithm: state and action. In this problem the state is the position of the agent on the cliff map, and the action is the agent's next move. I encode the actions as (0, 1, 2, 3, 4), corresponding to (stay, up, down, left, right). Note that the next action is chosen at random, but it still depends on the current state.
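
As a small illustration of this encoding (not part of the original program), each action index can be mapped to a row/column offset; the appendix implements the same mapping with if-branches in transition():

	# action index -> (row offset, column offset)
	MOVES = {
	    0: (0, 0),   # stay
	    1: (-1, 0),  # up
	    2: (1, 0),   # down
	    3: (0, -1),  # left
	    4: (0, 1),   # right
	}
	
	def step(state, action):
	    # apply an action to a (row, col) state; bounds checking is left out
	    row, col = state
	    d_row, d_col = MOVES[action]
	    return (row + d_row, col + d_col)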

The core of Q-learning is the transition rule. The agent keeps learning by this rule, stores everything it has learned in the Q-table, and accumulates experience over repeated iterations until it reaches the goal we set. One such run of trial-and-error learning that ends at the goal is an episode:

$$Q(s,a) = R(s,a) + \gamma \cdot \max\{Q(\tilde{s},\tilde{a})\}$$

Here $s, a$ are the state and action of the current step, $\tilde{s}, \tilde{a}$ are the state and action of the next step, and the learning parameter satisfies $0 < \gamma < 1$; the closer $\gamma$ is to 1, the more weight the agent puts on long-term returns.
When the Q-table is initialized, the agent knows nothing about its surroundings, so the table is initialized as a zero matrix.
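
Written out in code, one application of the transition rule looks roughly like the sketch below. q_update is a hypothetical standalone helper (the appendix program performs the same update inline inside main()); the Q-table is stored as an array of shape (rows, cols, actions), and gamma = 0.7 as in the appendix:

	def q_update(q, reward, state, action, next_state, next_valid_actions, gamma=0.7):
	    # transition rule: Q(s,a) = R(s,a) + gamma * max over valid a' of Q(s',a')
	    row, col = state
	    next_row, next_col = next_state
	    best_next = max(q[next_row, next_col, a] for a in next_valid_actions)
	    q[row, col, action] = reward[next_row, next_col] + gamma * best_next
	
	# usage sketch: start from q = np.zeros((4, 12, 5)) and call q_update
	# once for every step of every episode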

3. Algorithm Implementation

The implementation follows the pseudocode below:
[Figure: Q-learning pseudocode]
The full program is given in the appendix.
Key points of the program:

  1. The core code follows the pseudocode directly, but the individual methods have to be implemented by hand; the comments in the program can be used as a reference.
  2. We need to determine which actions the agent can take in a given state; this is handled by valid_action(self, current_state).

**Problem found:** In the original problem the reward at the goal G is also -1. The program does step onto G, but the value function does not converge there; because the reward circles pay more, the agent ends up converging to the reward circles and moving back and forth among the three of them. I therefore changed the reward at the goal G to 100, after which the values converge at the goal. I later saw that this corresponds to the absorbing goal described in the literature.

4. Results

The final Q-table matrix is too large to show here, so it is given in the appendix. To make the result easier to see, I also wrote a small program that animates the drawing of all the learned paths. To watch the animation, run the program. The final result is shown below:
[Figure: the learned paths drawn by the turtle program]
As the figure shows, the agent avoids the whole cliff, collects all the rewards, and finally reaches the goal.
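
For readers who do not want to run the turtle animation, the greedy path can also be read off the final Q-table in plain text. The sketch below assumes a trained Cliff object from the appendix; greedy_path is a hypothetical helper, not part of the original program:

	def greedy_path(cliff):
	    # follow the highest-valued valid action from S until G is reached
	    state = cliff.start_state
	    path = [state]
	    while state != cliff.end_state:
	        row, col = state
	        actions = cliff.valid_action(state)
	        best = max(actions, key=lambda a: cliff.q_matrix[row][col][a])
	        state = cliff.transition(state, best)
	        path.append(state)
	    return path
	
	# usage, assuming `cliff` is a trained Cliff instance:
	# print(greedy_path(cliff))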

5. Appendix

Program:

	# -*- coding: utf-8 -*-
	# qvkang
	import numpy as np
	import random
	import turtle as t

	class Cliff(object):
	    def __init__(self):
	        self.reward = self._reward_init()
	        print(self.reward)
	        self.row = 4
	        self.col = 12
	        self.gamma = 0.7
	        self.start_state = (3, 0)
	        self.end_state = (3, 11)
	        # one Q-value per cell and per action (stay, up, down, left, right)
	        self.q_matrix = np.zeros((4, 12, 5))
	        self.main()

	    def _reward_init(self):
	        re = np.ones((4, 12)) * -5
	        # reward circles
	        re[1][5:8] = np.ones(3) * -1
	        # cliff
	        re[3][1:11] = np.ones(10) * -100
	        # goal (changed from -1 to 100 so the values converge there)
	        re[3][11] = 100
	        return re

	    def valid_action(self, current_state):
	        # actions available in the current state (0 = stay is always valid)
	        itemrow, itemcol = current_state
	        valid = [0]
	        if itemrow - 1 >= 0: valid.append(1)
	        if itemrow + 1 <= self.row - 1: valid.append(2)
	        if itemcol - 1 >= 0: valid.append(3)
	        if itemcol + 1 <= self.col - 1: valid.append(4)
	        return valid

	    def transition(self, current_state, action):
	        # move from the current state to the next state
	        itemrow, itemcol = current_state
	        if action == 0: next_state = current_state
	        if action == 1: next_state = (itemrow - 1, itemcol)
	        if action == 2: next_state = (itemrow + 1, itemcol)
	        if action == 3: next_state = (itemrow, itemcol - 1)
	        if action == 4: next_state = (itemrow, itemcol + 1)
	        return next_state

	    def _indextoPosition(self, index):
	        # flat index -> (row, col); helper, not used by the main loop
	        itemrow = index // self.col
	        itemcol = index % self.col
	        return (itemrow, itemcol)

	    def _positiontoIndex(self, itemrow, itemcol):
	        # (row, col) -> flat index; helper, not used by the main loop
	        itemindex = itemrow * self.col + itemcol
	        return itemindex

	    def getreward(self, current_state, action):
	        # reward received for taking this action from the current state
	        next_state = self.transition(current_state, action)
	        next_row, next_col = next_state
	        r = self.reward[next_row, next_col]
	        return r

	    def path(self):
	        # draw the greedy path with the turtle library
	        t.speed(10)
	        paths = []
	        current_state = self.start_state
	        t.pensize(5)
	        t.penup()
	        # move to the start position
	        start_row, start_col = current_state
	        t.goto(start_col * 20, 60 - start_row * 20)
	        t.pendown()
	        paths.append(current_state)
	        while current_state != self.end_state:
	            current_row, current_col = current_state
	            valid_actions = self.valid_action(current_state)
	            valid_values = [self.q_matrix[current_row][current_col][a] for a in valid_actions]
	            max_value = max(valid_values)
	            # break ties at random among the valid actions with the highest Q-value
	            best_actions = [a for a in valid_actions
	                            if self.q_matrix[current_row][current_col][a] == max_value]
	            action = random.choice(best_actions)
	            print(current_state, '-------------', action)
	            next_state = self.transition(current_state, action)
	            paths.append(next_state)
	            next_row, next_col = next_state
	            t.goto(next_col * 20, 60 - next_row * 20)
	            current_state = next_state

	    def main(self):
	        # main training loop: 1000 episodes
	        for i in range(1000):
	            current_state = self.start_state
	            while current_state != self.end_state:
	                action = random.choice(self.valid_action(current_state))
	                next_state = self.transition(current_state, action)
	                future_rewards = []
	                for action_next in self.valid_action(next_state):
	                    next_row, next_col = next_state
	                    future_rewards.append(self.q_matrix[next_row][next_col][action_next])
	                # core transition rule: Q(s,a) = R(s,a) + gamma * max Q(s',a')
	                q_state = self.getreward(current_state, action) + self.gamma * max(future_rewards)
	                current_row, current_col = current_state
	                self.q_matrix[current_row][current_col][action] = q_state
	                current_state = next_state
	                # print(self.q_matrix)
	        # draw the greedy path 1000 times (ties are broken at random)
	        for i in range(1000):
	            self.path()
	        print(self.q_matrix)

	if __name__ == "__main__":
	    Cliff()

Final Q-table matrix:

	[[[ -14.84480118    0.          -14.06400168    0.          -14.06400168]
	  [ -14.06400168    0.          -12.94857383  -14.84480118  -12.94857383]
	  [ -12.94857383    0.          -11.35510547  -14.06400168  -11.35510547]
	  [ -11.35510547    0.           -9.07872209  -12.94857383   -9.07872209]
	  [  -9.07872209    0.           -5.82674585  -11.35510547   -5.82674585]
	  [  -5.82674585    0.           -1.1810655    -9.07872209   -5.1810655 ]
	  [  -5.1810655     0.           -0.258665     -5.82674585   -4.258665  ]
	  [  -4.258665      0.            1.05905      -5.1810655    -2.94095   ]
	  [  -2.94095       0.            2.9415       -4.258665      2.9415    ]
	  [   2.9415        0.           11.345        -2.94095      11.345     ]
	  [  11.345         0.           23.35          2.9415       23.35      ]
	  [  23.35          0.           40.5          11.345         0.        ]]
	
	 [[ -14.06400168  -14.84480118  -14.84480118    0.          -12.94857383]
	  [ -12.94857383  -14.06400168  -14.06400168  -14.06400168  -11.35510547]
	  [ -11.35510547  -12.94857383  -12.94857383  -12.94857383   -9.07872209]
	  [  -9.07872209  -11.35510547  -11.35510547  -11.35510547   -5.82674585]
	  [  -5.82674585   -9.07872209   -9.07872209   -9.07872209   -1.1810655 ]
	  [  -1.1810655    -5.82674585   -5.82674585   -5.82674585   -0.258665  ]
	  [  -0.258665     -5.1810655    -2.94095      -1.1810655     1.05905   ]
	  [   1.05905      -4.258665      2.9415       -0.258665      2.9415    ]
	  [   2.9415       -2.94095      11.345         1.05905      11.345     ]
	  [  11.345         2.9415       23.35          2.9415       23.35      ]
	  [  23.35         11.345        40.5          11.345        40.5       ]
	  [  40.5          23.35         65.           23.35          0.        ]]
	
	 [[ -14.84480118  -14.06400168  -15.39136082    0.          -14.06400168]
	  [ -14.06400168  -12.94857383 -109.84480118  -14.84480118  -12.94857383]
	  [ -12.94857383  -11.35510547 -109.06400168  -14.06400168  -11.35510547]
	  [ -11.35510547   -9.07872209 -107.94857383  -12.94857383   -9.07872209]
	  [  -9.07872209   -5.82674585 -106.35510547  -11.35510547   -5.82674585]
	  [  -5.82674585   -1.1810655  -104.0787221    -9.07872209   -2.94095   ]
	  [  -2.94095      -0.258665   -102.058665     -5.82674585    2.9415    ]
	  [   2.9415        1.05905     -97.94095      -2.94095      11.345     ]
	  [  11.345         2.9415      -92.0585        2.9415       23.35      ]
	  [  23.35         11.345       -83.655        11.345        40.5       ]
	  [  40.5          23.35        -30.           23.35         65.        ]
	  [  65.           40.5         100.           40.5           0.        ]]
	
	 [[ -15.39136082  -14.84480118    0.            0.         -109.84480118]
	  [-109.84480118  -14.06400168    0.          -15.39136082 -109.06400168]
	  [-109.06400168  -12.94857383    0.         -109.84480118 -107.94857383]
	  [-107.94857383  -11.35510547    0.         -109.06400168 -106.35510547]
	  [-106.35510547   -9.07872209    0.         -107.94857383 -104.0787221 ]
	  [-104.0787221    -5.82674585    0.         -106.35510547 -102.058665  ]
	  [-102.058665     -2.94095       0.         -104.0787221   -97.94095   ]
	  [ -97.94095       2.9415        0.         -102.058665    -92.0585    ]
	  [ -92.0585       11.345         0.          -97.94095     -83.655     ]
	  [ -83.655        23.35          0.          -92.0585      -30.        ]
	  [ -30.           40.5           0.          -83.655       100.        ]
	  [   0.            0.            0.            0.            0.        ]]]