求助-強化學習基礎-K-搖臂老虎機Python

按照周志華西瓜書第16章K-搖臂賭博機的僞碼編的程序:

# -*- coding: utf-8 -*-
"""
e貪心和Softmax
2-搖臂賭博機
搖臂1:0.4概率獎勵1,0.6-0
搖臂2:0.2-1,     0.8-1

@author: y1064
"""
import numpy as np
import matplotlib.pyplot as plt

K = 2 # 搖臂數
R = [[1,0],[1,0]] # 獎賞函數
probs = [[0.4,0.6],[0.2,0.8]] # 對應的概率

T = 5000 # 嘗試次數
e = 0.1  # 探索概率
def e_greedy(e):
    """伊普西龍貪心"""
    Q = [0,0]   # 記錄搖臂的平均獎賞
    count = [0,0]  # 記錄搖臂的探索次數
    r = 0
    r_list = []
    for t in range(T):
        if np.random.uniform()<e:
            k = np.random.choice(range(K)) # 從搖臂中均勻地選擇 
        else:
            k = np.argmax(Q)
        v = np.random.choice(R[k],p=probs[k])
        r+=v
        Q[k] = (Q[k]*count[k]+v)/(count[k]+1) # 更新平均獎賞
        count[k]+=1
        r_list.append(r/(t+1))  # 平均累積獎賞
    return r_list

def softmax(tau):
    """softmax """
    Q = [0,0]
    count = [0,0]
    r = 0
    r_list = []
    for t in range(T):
        sum_p=sum([np.exp(i/tau) for i in Q])
        P=[np.exp(i/tau)/sum_p for i in Q]
        k=int(np.random.choice([0,1],p=P))

        v = np.random.choice(R[k], p=probs[k])
        r+=v
        Q[k] = (Q[k]*count[k]+v)/(count[k]+1) # 更新平均獎賞
        count[k]+=1
        r_list.append(r/(t+1))  # 平均累積獎賞
    return r_list


plt.plot(e_greedy(e=0.1),label='e-greedy,e=0.1')
plt.plot(e_greedy(e=0.01),label='e-greedy,e=0.01')
plt.plot(softmax(tau=0.01),label='softmax,tau=0.01')
plt.plot(softmax(tau=0.1),label='softmax,tau=0.1')
plt.legend()

在這裏插入圖片描述
結果跟書中差距很大,暫時看不出代碼哪裏錯了,求助!剛開始強化學習,積極性就要被無情打消了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章