KNN算法之KD樹實現原理

一. KD樹的建立

KD樹算法包括三步,第一步是建樹,第二步是搜索最近鄰,最後一步是預測。

有二維樣本6個,{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},構建kd樹的具體步驟爲:

1)找到劃分的特徵。6個數據點在x,y維度上的數據方差分別爲5.80,4.47,所以在x軸上方差更大,用第1維特徵建樹。

計算x 軸的方差:

Ex=(2+5+9+4+8+7) / 6 = 35/6
Sx=((2- 35/6)^2 + (5- 35/6) ^ 2 + (9 -35/6) ^2 + (4- 35/6) ^2 +(8 - 35/6) ^ 2 + (7-35/6) ^ 2) /6 = 5.80

計算y軸的方差:

Ey =(3+4+6+7+1+2) / 6 = 23/6
Sy=((3-23/6)^2 + (4-23/6) ^2 + (6-23/6) ^2+ (7-23/6) ^2 +(1-23/6) ^2 + (2 -23/6) ^2)/6 = 4.47

2)確定劃分點(7,2)。根據x維上的值將數據排序,6個數據的中值(所謂中值,即中間大小的值)爲7,所以劃分點的數據是(7,2)。這樣,該節點的分割超平面就是通過(7,2)並垂直於:劃分點維度的直線x=7;

在這裏插入圖片描述

3)確定左子空間和右子空間。 分割超平面x=7將整個空間分爲兩部分:x<=7的部分爲左子空間,包含3個節點={(2,3),(5,4),(4,7)};另一部分爲右子空間,包含2個節點={(9,6),(8,1)}。

劃分左邊三個點:
計算x軸方差:

Ex = (2 + 5 + 4)/3 = 11/3
Sx = ((2 - 11/3) ^ 2 + (5-11/3) ^ 2 + (4 - 11/3) ^2 ) /3= 1.56

計算y 軸方差:

Ey = (3 + 4 + 7) / 3 = 14/3
Sy = ((3- 14/3) ^ 2 +(4 -14/3) ^ 2 +(7 - 14/3) ^ 2)/3 = 2.89

所以y軸的方差比x軸的方差大,按y軸的方向進行分割,先對y軸座標進行排序,找出劃分點. 查找後得到劃分點是(5,4), 並且垂直於y = 4
此時已經把左邊點分開了,y軸比4小的爲於直線下方,比4大的位於直線上方。

在這裏插入圖片描述

4)用同樣的辦法劃分左子樹的節點{(2,3),(5,4),(4,7)}和右子樹的節點{(9,6),(8,1)}。最終得到KD樹。

在這裏插入圖片描述上面的過程用樹結構(排序樹)來表示就是:

在這裏插入圖片描述

二. KD樹搜索最近鄰

生成KD樹以後,就可以去預測測試集裏面的樣本目標點了。對於一個目標點,首先在KD樹裏面找到包含目標點的葉子節點。以目標點爲圓心,以目標點到葉子節點樣本實例的距離爲半徑,得到一個超球體,最近鄰的點一定在這個超球體內部。
用建立的KD樹,來看對點(2, 4.5)找最近鄰的過程。
(1) 首先在KD樹裏面找到包含目標點的葉子節點,從根結點開始遍歷,找到葉子結點(4, 7) :
在這裏插入圖片描述然後以目標點爲圓心,葉子結點(4,7)到目標點的距離爲半徑畫圓,然後回溯到父結點(5,4). 這裏發現(5,4 )到目標點的距離比(4,7)到目標點的距離近,所以我們直接以(2,4,5) 爲圓心,(5,4)爲半徑畫圓。

在這裏插入圖片描述從圖中發現在圈內還有一個點(2,3),那麼現在比較發現該點到目標點的距離比(5,4)到目標點的距離還小,那麼接下來以(2,3)到目標點的距離爲半徑,(2, 4.5)爲圓心畫圓。

在這裏插入圖片描述發現圈內再沒有其它點,搜索路徑回溯完,返回最近鄰點(2,3),最近距離1.5。

三. KD樹預測

有了KD樹搜索最近鄰的辦法,KD樹的預測就很簡單了,在KD樹搜索最近鄰的基礎上,我們選擇到了第一個最近鄰樣本,就把它置爲已選。在第二輪中,我們忽略置爲已選的樣本,重新選擇最近鄰,這樣跑k次,就得到了目標的K個最近鄰,然後根據多數表決法,如果是KNN分類,預測爲K個最近鄰里面有最多類別數的類別。如果是KNN迴歸,用K個最近鄰樣本輸出的平均值作爲迴歸預測值。

四. sklearn 實現KD樹

# 基礎結構.py
#
import numpy as np
from sklearn import linear_model, svm, neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.model_selection import cross_val_score

# 關閉報警
import warnings
warnings.filterwarnings("ignore")
np.random.RandomState(0)

# 加載數據
iris = datasets.load_iris()
x, y = iris.data, iris.target

# 劃分訓練集與測試集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

# 數據預處理
scaler = preprocessing.StandardScaler().fit(x_train)
x_train = scaler.transform(x_train)
x_test = scaler.transform(x_test)

# 創建模型
clf = neighbors.KNeighborsClassifier(n_neighbors=12,algorithm='kd_tree')
# clf = linear_model.SGDClassifier()
# clf = linear_model.LogisticRegression()
# clf = svm.SVC(kernel='rbf')

# 模型擬合
clf.fit(x_train, y_train)

# 預測
y_pred = clf.predict(x_test)

# 評估
print(accuracy_score(y_test, y_pred))

# f1_score
print(f1_score(y_test, y_pred, average='micro'))

# 分類報告
print(classification_report(y_test, y_pred))

# 混淆矩陣
print(confusion_matrix(y_test, y_pred))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章