- 前提基礎:
KNN 基本原理(本文參考李航博士所著的《統計學習方法》)
- 距離度量方式
- 歐式距離
- 曼哈頓距離
- 夾角餘弦
- 切比雪夫距離
- 馬氏距離
- k值的選擇
- 算法優缺點
- 優點:精度高、對異常值不敏感、無數據輸入假定。
- 缺點:計算複雜度高、空間複雜度高。
- 適用數據範圍:數值型和標稱型。
- 代碼如下:
#encoding=utf-8
import pandas as pd
import numpy as np
import time
from sklearn.cross_validation import train_test_split
from sklearn.metrics import accuracy_score
class KNN(object):
    """Brute-force k-nearest-neighbour classifier using Euclidean distance.

    Nothing is fitted ahead of time: every call to :meth:`predict` compares
    each query sample against all training samples and takes a majority vote
    among the ``k`` closest ones.
    """

    def __init__(self, k=10):
        """Create a classifier.

        Args:
            k: Number of neighbours that vote on each prediction. Defaults
                to 10, the value that used to be hard-coded.
        """
        self.k = k

    def predict(self, testset, trainset, train_labels):
        """Predict a label for every sample in ``testset``.

        Args:
            testset: Query samples, one per row (array or nested sequence).
            trainset: Training samples, laid out like ``testset``.
            train_labels: One label per training sample. Any values that
                ``numpy.unique`` can sort are accepted, not only 0..9.

        Returns:
            numpy.ndarray holding one predicted label per query sample, in
            the order of ``testset``.

        Raises:
            ValueError: If ``k`` is smaller than 1, the training set is
                empty, or ``trainset`` and ``train_labels`` differ in length.

        Tie-breaking: among equidistant training samples the one with the
        lower index is preferred; among equally frequent labels the smallest
        label wins.
        """
        trainset = np.asarray(trainset)
        train_labels = np.asarray(train_labels)
        n_train = len(train_labels)

        if self.k < 1:
            raise ValueError('k must be at least 1, got %r' % (self.k,))
        if n_train == 0:
            raise ValueError('the training set is empty')
        if len(trainset) != n_train:
            raise ValueError(
                'trainset has %d samples but train_labels has %d'
                % (len(trainset), n_train))

        # Never ask for more neighbours than there are training samples.
        k = min(self.k, n_train)

        predictions = []
        for test_vec in np.asarray(testset):
            # Broadcasting builds a temporary as large as trainset for each
            # query; flattening per sample makes the row-wise norm work for
            # scalar, vector and matrix-shaped samples alike.
            diffs = (trainset - test_vec).reshape(n_train, -1)
            dists = np.linalg.norm(diffs, axis=1)
            # mergesort is stable, so equal distances keep training order.
            nearest = np.argsort(dists, kind='mergesort')[:k]
            predictions.append(self._vote(train_labels[nearest]))
        return np.array(predictions)

    @staticmethod
    def _vote(neighbour_labels):
        """Return the most frequent label; ties go to the smallest label."""
        # np.unique returns the labels sorted and argmax picks the first
        # maximum, which together give the smallest label among ties.
        candidates, counts = np.unique(neighbour_labels, return_counts=True)
        return candidates[np.argmax(counts)]
if __name__ == '__main__':
    # Script entry point: load the CSV, hold out 25% of the rows, classify
    # the held-out rows with KNN and report timing and accuracy.
    #
    # NOTE(review): this file imports train_test_split from
    # sklearn.cross_validation, which was removed in scikit-learn 0.20; on
    # newer releases the function lives in sklearn.model_selection.
    print('Start read data')
    read_start = time.time()
    raw_data = pd.read_csv("../data/train.csv", header=0)
    read_end = time.time()
    print('Read data cost  {}  second \n'.format(read_end - read_start))

    # DataFrame.info() writes its summary to stdout itself and returns None,
    # so it is called directly instead of being printed.
    raw_data.info()
    print(raw_data.head())

    # Column 0 is used as the label; the remaining columns are the features.
    data = raw_data.values
    imgs = data[:, 1:]
    labels = data[:, 0]

    train_features, test_features, train_labels, test_labels = train_test_split(
        imgs, labels, test_size=0.25, random_state=33)
    print(train_features.shape)
    print(test_features.shape)

    print('Start predicting')
    # Start the clock here so the reported time covers prediction only, not
    # the summary printing and the train/test split above.
    predict_start = time.time()
    knn = KNN()
    test_predict = knn.predict(test_features, train_features, train_labels)
    predict_end = time.time()
    print('Predicting cost  {}  second \n'.format(predict_end - predict_start))

    score = accuracy_score(test_labels, test_predict)
    print('The accuracy score is  {}'.format(score))