K-近鄰算法(KNN)原理分析和代碼實戰
前言
K-近鄰算法,全稱爲K-nearest neighbor,簡稱KNN。它是一個原理非常簡單,但是計算複雜度比較高的一個分類算法,接下來,我們先從原理出發,再進行源代碼的解析。
源代碼地址:KNN
原理分析
通過計算輸入數據與模型數據的歐幾里得距離,選取前K個距離最短的模型數據,類型出現次數最多的就是輸入數據所屬的類型。
我們來看一下下面這個圖(畫的不好,大家多多擔待)
上圖中,黑色點爲輸入數據,棕色和紅色數據均爲模型數據,我們假設棕色數據屬於1類,紅色數據屬於2類,假設K等於5.
步驟:
- 計算黑色數據與棕色數據和紅色數據之間的距離(歐幾里得距離)
- 找出與黑色數據距離最近的五個數據,如圖中橘黃色線段
- 統計這五個數據所屬的分類。圖中這5個數據中,有3個是紅色數據,屬於2類,2個棕色數據,屬於1類
- 選擇數量最多的類別,即爲輸入數據的類別。圖中5個數據中,紅色數據個數大於棕色數據個數,所以,輸入數據屬於2類。
歐幾里得距離公式:
二維
多維
原理很簡單,接下來咱們分析一下算法優缺點
優點:
- 原理簡單,不涉及複雜的數據理論知識,只有一個歐幾里得距離計算
- 對異常數據不敏感
- 精準度比較高
- 適用於數值型數據和標稱型(就是取值有限,比如0、1或者是、否)數據
缺點:
- 計算量太大,每次輸入數據,都需要與模型中所有數據進行歐幾里得距離計算
- 佔用的空間比較大。
源代碼解析
項目背景:
此項目數據集使用得是《機器學習實戰》一書提供得關於約會對象匹配得數據集,該數據集共有四列數據,前三列是數據的屬性,分別是 行里程數、玩遊戲時間佔比、消耗冰淇淋公升數,最後一列是數據的歸屬類,數據一共分類3類,分別是1、2、3.
數據存儲在txt文件中,不同屬性的數據使用空格進行分割,下圖是數據格式:
一、加載數據
import numpy as np
import operator
def loaddatasets(dataseturl,datatype='train'):
datasetLabel = []
datasetClass = []
with open(dataseturl) as f:
datas = f.readlines()
for data in datas:
dataline = data.strip().split('\t')
datasetLabel.append(dataline[:-1])
datasetClass.append(dataline[-1])
if(type=='train'):
datasetLabel = datasetLabel[:900]
datasetClass = datasetClass[:900]
else:
datasetLabel = datasetLabel[900:]
datasetClass = datasetClass[900:]
return datasetLabel,datasetClass
此方式是加載數據,這裏原數據一共有1000個,由於數據本身就是亂序,所以我們不需要對數據進行亂序處理。我們選取前900個數據爲模型數據,後100個數據作爲測試數據。分別將數據的屬性和數據所屬類存儲到不同的列表中。
二、數據歸一化
## 數據歸一化
def normalized_dataset(dataset):
dataset = np.array(dataset,dtype='float')
max = np.max(dataset,axis=0)
min = np.min(dataset,axis=0)
result = (dataset-min)/(max-min)
return result,max,min
這裏使用的公式是(x-min)/(max-min).爲什麼要進行歸一化呢,從數據集中我們可以看到,這三個屬性的值差別很大。由於KNN算法是通過計算空間距離來判定數據歸屬,那麼,值比較大的就會對計算產生較大的影響,所以,在這裏,我們對數據進行歸一化處理,使其數據在0-1的範圍之間。
三、計算歐幾里得距離
## 計算歐幾里得距離
def calculate_distance(dataset,x):
#此時算出了新數據x與原來每個數據之間的距離
result = np.sqrt(np.sum(np.power((dataset-x),2),axis=1))
#返回值是形狀爲(length,1)的數組
return result
這裏就是計算歐幾里得距離,所使用的公式就是上面圖中所給的公式。
四、進行分類計算
## 進行分類
def KnnClassify(k,inputdata,datasetLabel,datasetClass):
# print(result)
distance = calculate_distance(datasetLabel,inputdata)
sortdistanceindex = np.argsort(distance)
#print("sortdistanceindex",sortdistanceindex)
classcount={ }
for i in range(k):
klist=datasetClass[sortdistanceindex[i]]
classcount[klist] = classcount.get(klist,0)+1
#這裏需要記錄一下,如何對字典中某一屬性進行排序
sortedClassCount = sorted(classcount.items(),key=operator.itemgetter(1),reverse=True)
#print("sortedClassCount:",sortedClassCount)
return sortedClassCount[0][0]
- 前兩行計算輸入數據與模型數據的空間距離,然後對距離數據進行排序。argsort 這裏方法返回的是排序數據原來索引值,這樣做方便我們找到與之對應的原數據。
- 循環K次,找到距離最短的K個原數據的分類
- 創建一個字典,用於統計K個數據所屬分類的數量
- 使用sorted 方法,對字典進行升序排序,這個方法後面會詳細講一下
- 返回K個數據中,類別數量最多的的那個分類,就是輸入數據的分類
五、檢測模型精準度
def TestModelPrecision():
dataseturl = 'datasets/datingTestSet2.txt'
datatestLabel,datatestClass = loaddatasets(dataseturl,datatype='test')
datamodelLabel,datamodelClass = loaddatasets(dataseturl,datatype='train')
datatestLabel,_ ,_ = normalized_dataset(datatestLabel)
datamodelLabel,_,_ = normalized_dataset(datamodelLabel)
#print("normalize:",datasetLabel)
num=0
for i in range(len(datatestClass)):
DataClass = KnnClassify(k=3,inputdata=datatestLabel[i],datasetLabel=datamodelLabel,datasetClass=datamodelClass)
print("當前預測所屬類爲{},實際所屬類爲{}".format(DataClass,datatestClass[i]))
if(int(DataClass)==int(datatestClass[i])):
num+=1
return 100*num/len(datatestClass)
這裏主要用來檢測模型精準度,測試數據使用的就是數據集的後100個,精準率能達到96%,效果還不錯。
六、輸入數據分類
# 輸入數據進行分類
def ClassifyResult():
data1 = input("請輸入飛行里程數:")
data2 = input("請輸入玩遊戲時間佔比:")
data3 = input("請輸入消耗得冰淇淋公升數:")
dataseturl = 'datasets/datingTestSet2.txt'
datamodelLabel,datamodelClass = loaddatasets(dataseturl,datatype='train')
datamodelLabel,max,min = normalized_dataset(datamodelLabel)
inputdata = np.array([data1,data2,data3],dtype='float')
inputdata = (inputdata-min)/(max-min)#處理輸入的數據
DataClass = KnnClassify(k=3,inputdata=inputdata,datasetLabel=datamodelLabel,datasetClass=datamodelClass)
print("輸出結果是:",DataClass)
在這裏,我們可以輸入數據,來判斷數據的歸屬
知識點擴展
如何對字典進行排序?
#字典格式
classcount = {'a':2,'b':33,'c':5}
sorted() #這個方法是python自帶的一個排序方法,返回值是一個按照升序排序的列表,
#我們來分析下面這個
sorted(classcount.items(),key=operator.itemgetter(1),reverse=True)
classcount.items() 是將字典轉化爲元組
key=operator.itemgetter(1) 按照元組的第二個值進行排序
reverse=True 默認升序,這個屬性設置爲true,表示進行降序排列
這裏返回值值列表,列表的數據是元組形式。[('a', 2), ('c', 5), ('b', 33)]
結論
KNN算法原理非常簡單,非常容易理解,並且代碼也很好寫。但是往往,越容易理解,越簡單的東西,背後就會有一些東西被犧牲,比如計算資源和空間容量。而且KNN算法無法得知數據中的基礎結構信息,下一節的決策樹會解決這個問題。