KNN--K近鄰算法

一、KNN算法概述

KNN算法(k-NearestNeighbor),即K最近鄰,是一種監督學習(Supervised Learning)算法字面意思是與計算點最近的k個相鄰點,意思是說每個樣本點都可以用與之最近的k個相鄰的點來表示。

KNN是一種分類算法(Classification Algorithm),它所謂的學習過程是基於輸入的實例的,類似於懶惰學習(Lazy Learning),即KNN沒有明顯的學習過程,也就是說沒有訓練階段,數據集事先已有了分類和特徵值,待收到新樣本後直接進行處理,與急切學習(eager learning)相對應。

KNN算法原理比較簡單,但是可以解決很多比較經典的問題。KNN的原理大致是:計算待分類的數據點和訓練數據集(Training Dataset)中的每個數據點的歐式距離(Euclidean Distance,又稱歐幾里得距離),然後選取最近的k個數據點,並對這k個點的類別進行頻率統計,將這k個點中出現頻率最大的類別作爲待分類點的距離。

從百度百科中找到了這張圖加深理解。

W1、W2、W3是訓練數據集,圖上這些點分別屬於W1、W2、W3這三類,Xu是待分類點,利用KNN算法測得Xu的所在類別大致流程是這樣,輸入訓練數據和Xu,然後計算Xu與每個點的歐式距離,選取最近的5個點,如圖,箭頭所指的5個點,這五個點中四個是紅色,一個是綠色的,所以Xu應該屬於紅色這個類別,即W1。

具體算法描述

1. 計算待測點到每個training dataset中的點的距離並保存,時間複雜度O(n),空間複雜度O(n)。

2. 選取距離最小的k個點,時間複雜度O(logn),空間複雜度O(k)。

3. 對這k個點進行統計,得到每種類別出現的頻率。O(k)

4. 選取頻率最大的類別作爲待分類點的類別。O(k)

二、關於k值的選取

k值的含義是所選取的臨近點數。

k值不能過大也不能過小,不然會嚴重影響預測結果的準確度。k值過小時異常值或噪聲值的影響會過大,例如 k=1 時,距離待測點最近的那個點直接決定待測點的類別,顯然是不合理的,也是不準確的。當k過大,距離待測的較遠的點也會對待測點的計算產生影響,就失去了k最近鄰的意義,例如 k=n 時,即選取數據集中全部的點進行計算,這樣待測點的類別完全取決於訓練集中那個類別的樣本最多,顯然也是不合理的。

k值一般不會超過20,上限是n的平方根,即訓練集越大,k值應該越大。具體的k值選取應該利用訓練集去檢驗得到一個準確度最高的k值。

三、關於距離的計算

一般情況是使用歐式距離,即歐幾里得距離。不過也可以用曼哈頓距離或角度弧度代替,具體問題應該具體分析。

定義

歐幾里得度量(euclidean metric)(也稱歐氏距離)是一個通常採用的距離定義,指在m維空間中兩個點之間的真實距離,或者向量的自然長度(即該點到原點的距離)。在二維和三維空間中的歐氏距離就是兩點之間的實際距離。 [1] 

計算公式

二維空間的公式

兩點之間的歐式距離

 

三維空間的公式

n維空間的公式

四、總結

1. KNN算法比較簡單容易理解,處理一些典型問題時的表現也非常好,但是它的時間複雜度和空間複雜度都非常高。

2. KNN對異常值不是很敏感,因爲是根據多個近鄰平均效果得到最終的分類結果。

3. KNN對於隨機分佈的數據集分類效果較差,對於類內間距小,類間間距大的數據集分類效果好,而且對於邊界不規則的數據效果好於線性分類器。

4. KNN對於樣本不均衡的數據效果不好,需要進行改進。改進的方法時對k個近鄰數據賦予權重,比如距離測試樣本越近,權重越大。

5. KNN很耗時,時間複雜度爲O(n),一般適用於樣本數較少的數據集,當數據量大時,可以將數據以樹的形式呈現,能提高速度,常用的有kd-tree和ball-tree。

(前兩點來自《機器學習實戰》這本書,後三點來自大佬的博客:https://www.cnblogs.com/jyroy/p/9427977.html#idx_6,卑微的我🙁)

五、Sklearn代碼實現

# -*- coding: utf-8 -*-
# @Author: yezhipeng
# @Date:   2019-08-04 16:26:19
# @Last Modified by:   yezhipeng
# @Last Modified time: 2019-08-04 16:56:43
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()
iris_x = iris.data
iris_y = iris.target

# print(iris_x)
# print(iris_y)

x_train,x_test,y_train,y_test = train_test_split(iris_x,iris_y,test_size = 0.3)

# 默認打亂數據,避免產生不必要的誤差
# print(y_train)

# 實例化一個k近鄰分類器的對象
knn = KNeighborsClassifier()
# 擬合數據
knn.fit(x_train,y_train)
print(knn.predict(x_test))
print(y_test)

(代碼來自莫煩python的“Scikit-learn (sklearn) 優雅地學會機器學習”視頻教程中的Knn分類iris花的例子,鏈接在這裏:https://www.bilibili.com/video/av17003173?from=search&seid=14492162154020710353)

後續會補充《機器學習實戰》上面的案例,如有錯誤評論聯繫更正。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章