機器學習(10.2)--手寫數字識別的不同算法比較(2)--KNN算法

KNN(k-NearestNeighbor)是監督學習的分類技術中最簡單的方法之一,K指k個最近的鄰居的意思,
關於KNN的詳細基本實現原理,可參考
機器學習(2)--鄰近算法(KNN)
tensorflow的實現方式,可參考:
tensorflow實例(9)--最鄰近算法實現MNIST手寫數字分類算法
關於使用的數據集,可參考

機器學習(10.1)--手寫數字識別的不同算法比較(1)--mnist數據集不同版本解析及平均灰度實踐

計算有點慢,有耐心的可以慢慢等結果,沒耐心的嘛,可以減少一些測試集的數量看個結果,正確率還是不錯,可以到94%


# -*- coding:utf-8 -*-
import pickle  
import gzip  
import numpy as np
with gzip.open(r'mnist.pkl.gz', 'rb')   as f:
    training_data, validation_data, test_data = pickle.load(f,encoding='bytes') 
test_data=list(zip(test_data[0],test_data[1]))

k=3
count=0
for index,item in enumerate(test_data):
	distance= ((training_data[0]-item[0])**2).sum(axis=1)
	k_types=np.zeros(10)# 統計最近K個點的類別數據結果 如:[0. 0. 2. 0. 0. 0. 0. 1. 0. 0.] 表示,離測試x最近的有,2個2,1個7
	for i in range(k):
		minDistIndex=np.argmin(distance) # 取得最近的距率的序號
		minType=training_data[1][minDistIndex] #取得最近距率的類別
		distance[minDistIndex]=99999999999#將最近距率設置爲一個大值,循環K次,取得前K個最大值
		k_types[minType]+=1
	if index % 100==0:
		print("計算到第%d條"%index)

	if np.argmax(k_types)==item[1]:
		count+=1
	else:
		print("第%d條,計算錯誤,預測爲%d,正確結果爲%d:"%(index,np.argmax(k_types),item[1]) )
print("計算完成,正確條數爲:%d"%count+",正確率爲"+str(round(count/len(test_data)*100,2))+"%")

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章