Python 簡單實現KNN算法

數據集是自己下載的 MNIST 手寫數字識別數據，包括一個訓練集文件 train.csv、一個測試集文件 test.csv，以及一個 submission.csv 文件（存放 test.csv 對應的標籤）。KNN 的原理很簡單，這裡就不多說了，直接上代碼吧。


# author: zhouchao
# date: 2017-12-07 11:13
# description: use KNN to recognize handwritten digits

import numpy as np
from numpy import *
import operator
from numpy import random  

def load_train_data(path, skiprows=0):
    """Load a labelled training CSV into feature vectors and labels.

    Each row is read as ``label,feature1,feature2,...``: column 0 is the
    class label and all remaining columns are the features.

    Args:
        path: path of the comma-separated training file.
        skiprows: number of leading lines to skip, e.g. 1 when the file
            starts with a header row. Defaults to 0 (the previous behaviour).

    Returns:
        ``(vec, labels)`` where ``vec`` is a 2-D float array with one row per
        sample and ``labels`` is a list of one-element lists
        (``[[label], ...]``), as produced by ``tolist()`` on the label column.
    """
    # ndmin=2 keeps a single-row file two-dimensional so the column slicing
    # below still works.
    train = np.loadtxt(path, delimiter=",", skiprows=skiprows, ndmin=2)
    vec = train[:, 1:]
    labels = train[:, 0:1].tolist()
    return vec, labels
def predict(line, vec, labels, k=20):
    """Classify one sample by majority vote among its k nearest neighbours.

    Distance is Euclidean.

    Args:
        line: 1-D array-like feature vector to classify; its length must
            match the number of columns of ``vec``.
        vec: 2-D NumPy array of training feature vectors, one row per sample.
        labels: sequence parallel to the rows of ``vec``. Each entry is
            indexable and holds the label at index 0 (e.g. ``[[label], ...]``
            as returned by ``load_train_data``).
        k: number of neighbours that vote. Defaults to 20, the value that
            used to be hard-coded. It is clamped to the number of training
            rows.

    Returns:
        The label with the most votes among the k nearest neighbours. Ties
        are broken in favour of the label whose nearest representative is
        closest to ``line``.

    Raises:
        ValueError: if ``vec`` has no rows or ``k`` is less than 1.
    """
    num_samples = vec.shape[0]
    if num_samples == 0:
        raise ValueError("training set is empty")
    if k < 1:
        raise ValueError("k must be at least 1")
    k = min(k, num_samples)

    # Broadcasting subtracts the query from every training row. Promoting the
    # query to float also stops integer data (e.g. uint8 pixels) from
    # wrapping around when subtracted and squared.
    diff = np.asarray(line, dtype=float) - vec
    distances = np.sqrt((diff ** 2).sum(axis=1))
    nearest = np.argsort(distances, kind="stable")[:k]

    votes = {}
    first_rank = {}
    for rank, idx in enumerate(nearest):
        label = labels[idx][0]
        votes[label] = votes.get(label, 0) + 1
        first_rank.setdefault(label, rank)

    # Highest vote count wins. On a tie, prefer the label that appeared
    # earliest, i.e. the one with the closest neighbour. This is
    # deterministic regardless of dict ordering.
    return max(votes, key=lambda lab: (votes[lab], -first_rank[lab]))
if __name__ == "__main__":
    # Load the training set, then print one predicted label per line of the
    # comma-separated test file.
    vec, labels = load_train_data("../../data/handwrite/train.csv")
    # NOTE(review): the accompanying text describes the test set as test.csv,
    # but this reads test.txt -- confirm which file is intended.
    with open("../../data/handwrite/test.txt") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines; int('') would otherwise raise ValueError.
                continue
            nums = [int(x) for x in line.split(",")]
            print(predict(np.array(nums), vec, labels))




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章