按照周志華西瓜書第16章K-搖臂賭博機的僞碼編的程序:
# -*- coding: utf-8 -*-
"""
e貪心和Softmax
2-搖臂賭博機
搖臂1:0.4概率獎勵1,0.6-0
搖臂2:0.2-1, 0.8-1
@author: y1064
"""
import numpy as np
import matplotlib.pyplot as plt
K = 2 # 搖臂數
R = [[1,0],[1,0]] # 獎賞函數
probs = [[0.4,0.6],[0.2,0.8]] # 對應的概率
T = 5000 # 嘗試次數
e = 0.1 # 探索概率
def e_greedy(e):
"""伊普西龍貪心"""
Q = [0,0] # 記錄搖臂的平均獎賞
count = [0,0] # 記錄搖臂的探索次數
r = 0
r_list = []
for t in range(T):
if np.random.uniform()<e:
k = np.random.choice(range(K)) # 從搖臂中均勻地選擇
else:
k = np.argmax(Q)
v = np.random.choice(R[k],p=probs[k])
r+=v
Q[k] = (Q[k]*count[k]+v)/(count[k]+1) # 更新平均獎賞
count[k]+=1
r_list.append(r/(t+1)) # 平均累積獎賞
return r_list
def softmax(tau):
"""softmax """
Q = [0,0]
count = [0,0]
r = 0
r_list = []
for t in range(T):
sum_p=sum([np.exp(i/tau) for i in Q])
P=[np.exp(i/tau)/sum_p for i in Q]
k=int(np.random.choice([0,1],p=P))
v = np.random.choice(R[k], p=probs[k])
r+=v
Q[k] = (Q[k]*count[k]+v)/(count[k]+1) # 更新平均獎賞
count[k]+=1
r_list.append(r/(t+1)) # 平均累積獎賞
return r_list
plt.plot(e_greedy(e=0.1),label='e-greedy,e=0.1')
plt.plot(e_greedy(e=0.01),label='e-greedy,e=0.01')
plt.plot(softmax(tau=0.01),label='softmax,tau=0.01')
plt.plot(softmax(tau=0.1),label='softmax,tau=0.1')
plt.legend()
結果跟書中差距很大,暫時看不出代碼哪裏錯了,求助!剛開始強化學習,積極性就要被無情打消了。