PLA
15
注意: 在原有4維數據的基礎上又增加了一維(恆爲1.0)作爲偏移量
import numpy as np
# 數據處理
def getData(file_name):
    """Load a whitespace-separated numeric table and prepend a bias column.

    Every non-blank line of the file is split on whitespace, each field is
    converted to ``float``, and a constant ``1.0`` is placed in front of the
    parsed values.

    Args:
        file_name: Path of the text file to read.

    Returns:
        A ``numpy.ndarray`` with one row per non-blank line; column 0 of
        every row is ``1.0``. The array is 2-D when all lines carry the
        same number of fields.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be parsed as a float.
    """
    rows = []
    # The context manager closes the handle even when parsing raises.
    with open(file_name) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank line would otherwise become the lone row (1.0,)
                # and make the resulting array ragged.
                continue
            rows.append([1.0] + [float(v) for v in fields])
    return np.array(rows)
def sign(y):
    """Turn a raw perceptron score into a class label.

    Args:
        y: Scalar score.

    Returns:
        ``-1`` when ``y <= 0`` (a score of exactly zero counts as
        negative), otherwise ``1``.
    """
    return -1 if y <= 0 else 1
# Problem 15: run the perceptron learning algorithm (PLA) over the samples in
# file order, repeating full passes until one pass makes no mistake, and
# count how many weight updates were needed.
data = getData("./data.txt")
print(data.shape)
train_time = 0  # total number of weight updates performed so far
# np.random.shuffle(data)  # Question 2: shuffle first; the result then differs

# Every column except the last (the label) is a feature, bias included.
n_features = data.shape[1] - 1
w = np.zeros(n_features, dtype=float)
while True:
    isFinish = True
    for sample in data:
        x, label = sample[:n_features], sample[-1]
        if sign(w.dot(x)) != label:
            w += x * label
            # w += x * label * 0.5  # Question 3: same update, step size 0.5
            isFinish = False
            train_time += 1
            # NOTE(review): the original indentation was lost, so it is
            # unclear whether this print ran once per update or once per
            # pass; it is placed per update here -- confirm.
            print("update paramter:", train_time)
    if isFinish:
        # A complete pass without any correction: the weights have converged.
        break
18
pocket: 在所有的錯誤點中進行參數更新。下段代碼中,每次迭代隨機選取50個錯誤點進行參數更新;如果更新後的參數使得錯誤率降低,則將最優參數更改爲當前參數,否則不進行修改。
import numpy as np
# 數據處理
def getData(file_name):
    """Read a whitespace-separated numeric file and add a leading bias column.

    Each non-blank line is split on whitespace, its fields are converted to
    ``float``, and a constant ``1.0`` is inserted in front of them.

    Args:
        file_name: Path of the text file to read.

    Returns:
        A ``numpy.ndarray`` holding one row per non-blank line, with ``1.0``
        in column 0. The array is 2-D when every line has the same number
        of fields.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be parsed as a float.
    """
    rows = []
    # ``with`` guarantees the file is closed, including on a parse error.
    with open(file_name) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Ignore blank lines; they would produce a ragged (1.0,) row.
                continue
            rows.append([1.0] + [float(v) for v in fields])
    return np.array(rows)
# sign
def sign(yhat):
    """Convert an array of scores to labels elementwise, treating 0 as -1.

    Args:
        yhat: Array of raw scores.

    Returns:
        A new array produced by ``np.sign`` in which every entry equal to
        zero has been replaced by ``-1``. The argument itself is not
        modified because ``np.sign`` returns a fresh array.
    """
    labels = np.sign(yhat)
    zero_mask = labels == 0
    labels[zero_mask] = -1
    return labels
# 計算錯誤率
def err_rate(yhat, data):
    """Return the fraction of predictions that disagree with the labels.

    Args:
        yhat: 1-D array of predicted labels, one per row of ``data``.
        data: 2-D array whose last column holds the true labels.

    Returns:
        Number of positions where ``yhat`` differs from the last column of
        ``data``, divided by ``yhat.shape[0]``.
    """
    labels = data[:, -1]
    mismatches = yhat != labels
    return mismatches.sum() / yhat.shape[0]
def Pla_train(data, w, iternum):
    """Run up to ``iternum`` pocket-PLA updates and return the best weights.

    Starting from ``w``, a misclassified sample is picked uniformly at
    random and the perceptron correction ``w += x * label`` is applied.
    After each correction the training error is measured, and the weights
    are kept ("pocketed") only when they beat the best error seen so far.

    Args:
        data: 2-D array; the first ``len(w)`` columns are the features
            (bias included) and the last column is the label.
        w: 1-D initial weight vector. It is copied, so the caller's array
            is left untouched.
        iternum: Maximum number of weight updates to perform.

    Returns:
        A tuple ``(best_w, best_err)``: the pocketed weight vector and its
        error rate on ``data``. When no update is performed these are a
        copy of the initial weights and their error rate.
    """
    # Work on a float copy so the in-place updates never leak to the caller.
    w = np.array(w, dtype=float)
    features = data[:, :w.shape[0]]
    labels = data[:, -1]

    yhat = sign(features.dot(w))
    best_w = w.copy()
    best_err = err_rate(yhat, data)
    for _ in range(iternum):
        wrong = np.flatnonzero(yhat != labels)
        # Test emptiness via ``size``: ``wrong.any()`` would also be False
        # when the only misclassified sample is row 0.
        if wrong.size == 0:
            break
        # Choose one misclassified sample uniformly at random.
        pos = np.random.choice(wrong)
        w += features[pos] * labels[pos]
        yhat = sign(features.dot(w))
        err = err_rate(yhat, data)
        if err < best_err:
            # Track the best error too, so later candidates are compared
            # against the pocket rather than against the starting point.
            best_w = w.copy()
            best_err = err
    return best_w, best_err
# 18
def train_18(train_path="./train.txt", test_path="./test.txt",
             rounds=2000, updates=50):
    """Run the Problem 18 pocket experiment and print train/test error rates.

    For each of ``rounds`` rounds, ``Pla_train`` is called with ``updates``
    as its iteration budget, and the returned weights are scored on the
    test set. The training error is printed every 100 rounds; the test
    error averaged over all rounds and the test error of the final weights
    are printed at the end.

    Args:
        train_path: Path of the training data file.
        test_path: Path of the test data file.
        rounds: Number of training rounds; must be at least 1.
        updates: Iteration budget passed to ``Pla_train`` each round.

    Raises:
        ValueError: If ``rounds`` is smaller than 1.
    """
    if rounds < 1:
        raise ValueError("rounds must be at least 1")

    # Load the training and test sets.
    data = getData(train_path)
    data_test = getData(test_path)

    # Every column except the last (the label) is a feature, bias included.
    n_features = data.shape[1] - 1
    # Builtin ``float``: the ``np.float`` alias was removed in NumPy 1.24.
    w = np.zeros(n_features, dtype=float)

    test_features = data_test[:, :n_features]
    err_test = 0
    for i in range(rounds):
        # NOTE(review): the weights returned by one round are fed into the
        # next instead of being reset to zero -- confirm this is intended.
        w, err_r = Pla_train(data, w, updates)
        if i % 100 == 0:
            print("當前訓練錯誤率:", err_r)
        # Accumulate this round's test error for the average below.
        yhat_test = sign(test_features.dot(w))
        err_test += err_rate(yhat_test, data_test)
    print("平均測試集錯誤率:", err_test / rounds)

    # Test error of the weights left after the final round.
    yhat_last = sign(test_features.dot(w))
    print("最終測試集錯誤率:", err_rate(yhat_last, data_test))
# Run the Problem 18 pocket experiment only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    train_18()