KNN模型筆記


一、KNN模型

KNN(K近鄰)模型,不會預先生成一個分類或預測模型,用於新樣本的預測,而是將模型的構建與未知數據的預測同時進行。
該算法對數據的分佈特徵沒有任何要求。

1 核心思想

比較已知y值的樣本與未知y值樣本的相似度,然後尋找最相似的k個樣本用作未知樣本的預測。
算法主要任務:

  • 確定最近鄰的個數k值;
  • 用於度量樣本間相似性的指標。

2 k值的選擇

k值的影響:

  • k值過於偏小,可能會導致模型的過擬合;
  • 反之,又可能會使模型進入欠擬合狀態

爲了獲得最佳的k值,可以考慮三種解決方案

  • 第一種,設置k近鄰樣本的投票權重。通常可以將權重設置爲距離的倒數;
  • 第二種,採用多重交叉驗證法,該方法是目前比較流行的方案,其核心就是將k取不同的值,然後在每種值下執行m重的交叉驗證,最後選出平均誤差最小的k值;
  • 第三種,結合前兩種方法,選出理想的k值。

3 相似度的度量方法

  • 歐氏距離
  • 曼哈頓距離
  • 餘弦相似度
  • 傑卡德相似係數

3.1 歐氏距離

該距離度量的是兩點之間的直線距離。公式如下(針對點A(x1,x2,...,xn)x_1,x_2,...,x_n)、B(y1,y2,...,yny_1,y_2,...,y_n)):

dA,B=(y1x1)2+(y2x2)2+...+(ynxn)2d_{A,B}=\sqrt{(y_1-x_1)^2+(y_2-x_2)^2+...+(y_n-x_n)^2}

3.2 曼哈頓距離

該距離也稱爲“曼哈頓街區距離”,度量的是兩點在軸上的相對距離總和。公式如下(針對點A(x1,x2,...,xn)x_1,x_2,...,x_n)、B(y1,y2,...,yny_1,y_2,...,y_n)):

dA,B=y1x1+y2x2+...+ynxnd_{A,B}=|y_1-x_1|+|y_2-x_2|+...+|y_n-x_n|

3.3 餘弦相似度

該相似度其實就是計算兩點所構成向量夾角的餘弦值,夾角越小,則餘弦值越接近於1,進而能夠說明兩點之間越相似。公式如下(針對點A(x1,x2,...,xn)x_1,x_2,...,x_n)、B(y1,y2,...,yny_1,y_2,...,y_n)):

SimilarityA,B=cosθ=ABABSimilarity_{A,B}=\cos \theta=\frac{\vec{A}·\vec{B}}{\mid\mid \vec{A}\mid \mid\mid\mid\vec{B}\mid\mid}
其中,點·代表兩個向量之間的內積,符號‖‖代表向量的模,即l2正則。

3.4 傑卡德相似係數

該相似係數與餘弦相似度經常被用於推薦算法,計算用戶之間的相似性。 公式如下:
J(A,B)=ABABJ(A,B)=\frac{\mid A\bigcap B\mid}{\mid A \bigcup B\mid}
其中,|A∩B|表示兩個用戶所購買相同商品的數量,|A∪B|代表兩個用戶購買所有產品的數量。

使用距離方法來度量樣本間的相似性時,必須注意兩點:

  • 一個是所有變量的數值化;
  • 另一個是防止數值變量的量綱影響,必須採用數據的標準化方法對其歸一化,使得所有變量的數值具有可比性。

4 近鄰樣本的搜尋方法

近鄰樣本的搜尋方法:

  • 暴力搜尋法:需要全表掃描,只能適合小樣本的數據集
  • KD樹搜尋法:不需要全表掃描
  • 球樹搜尋法:不需要全表掃描

4.1 KD樹搜尋法

K-Dimension Tree,K表示訓練集中包含的變量個數。其最大的搜尋特點:先利用所有已知類別的樣本點構造一棵樹模型,然後將未知類別的測試集應用在樹模型上,實現最終的預測功能。
KD樹搜尋法的兩個重要步驟

  • KD樹的構造
  • KD樹的搜尋

缺點:該方法在搜尋分佈不均勻的數據集時,效率會下降很多。

4.2 球樹搜尋法

球樹搜尋法能夠解決KD樹的缺陷,是因爲球樹將KD樹中的超矩形體換成了超球體,沒有了“角”,就不容易產生模棱兩可的區域。
優缺點:與KD樹的思想非常相似,所不同的是,球樹的最優搜尋路徑複雜度提高了,但是可以避免很多無謂樣本點的搜尋

球樹搜尋法的兩個重要步驟

  • 球樹的構造
  • 球樹的搜尋

5 KNN模型實例

Python中的sklearn的子模塊neighbors中有關KNN算法的類:KNeighborsClassifier類(分類)和KNeighborsRegressor類(預測)。

KNeighborsClassifier(
n_neighbors=5,
weights=‘uniform’, # ‘uniform’,表示所有近鄰樣本的投票權重一樣;如果爲’distance’,則表示投票權重與距離成反比
algorithm=‘auto’, #‘ball_tree’,則表示使用球樹搜尋法; ‘kd_tree’,則表示使用KD樹搜尋法; ‘brute’,則表示使用暴力搜尋法
leaf_size=30, # 用於指定球樹或KD樹葉子節點所包含的最小樣本量
p=2,
metric=‘minkowski’, # 用於指定距離的度量指標
metric_params=None,
n_jobs=None,
**kwargs,
)

以下以分類問題爲例:

# 導入第三方庫
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score # 交叉驗證

# 讀入數據
kdata = pd.read_excel(r'Knowledge.xlsx')
# --------------------構造訓練集和測試集------------------------
x_columns = kdata.columns[:-1]
y_column = kdata.columns[-1]
X_train,X_test,y_train,y_test = train_test_split(kdata[x_columns],
                                                 kdata[y_column],
                                                 test_size=0.25,
                                                 random_state=111)
#-------------------使用十折交叉驗證尋找最優k值-----------------
# 設置待測試的不同k值
K = np.arange(1,np.int(np.log2(kdata.shape[0])))
accuracy = [] # 用於存儲不同k值10折後的模型平均準確率
for k in K:
    # 使用十折交叉驗證比對不同k值的預測準確率
    cv_result = cross_val_score(KNeighborsClassifier(n_neighbors=k,
                                                     weights='distance'),
                               X_train,y_train,cv=10,scoring='accuracy')
    accuracy.append(cv_result.mean())
# -------------------可視化準確率結果-----------------------------
# 挑選出準確率最大值對應的下標
k_max_index = np.array(accuracy).argmax()
# 中文和負號的正常顯示
plt.rcParams['font.sans-serif']=['Microsoft YaHei']
plt.rcParams['axes.unicode_minus']= False
plt.plot(K,accuracy)
plt.scatter(K,accuracy)
plt.text(K[k_max_index],accuracy[k_max_index],'最佳k值爲{}'.format(K[k_max_index]))
plt.show()

在這裏插入圖片描述

# ---------------- 根據最佳k值 構建模型-----------------------
model = KNeighborsClassifier(n_neighbors=K[k_max_index],weights='distance')
model.fit(X_train,y_train)
y_pre = model.predict(X_test)
# 構建混淆矩陣
cross_matrix = pd.crosstab(y_pre,y_test)
# 將混淆矩陣構造成數據框,並加上字段名和行名稱,用於行或列的含義說明
cross_matrix = pd.DataFrame(cross_matrix)
# 繪製熱力圖
sns.heatmap(cross_matrix, annot = True,cmap = 'GnBu')
# 添加x軸和y軸的標籤
plt.xlabel(' 真實值')
plt.ylabel(' 預測值')
# 圖形顯示
plt.show()

在這裏插入圖片描述

# ----------------- 模型評估報告 -------------------------
print(metrics.classification_report(y_test,y_pre))
類別 precision recall f1-score support
High 1.00 0.96 0.98 27
Low 0.82 1.00 0.90 33
Middle 0.96 0.86 0.91 29
Very Low 1.00 0.75 0.86 12
accuracy 0.92 101
macro avg 0.95 0.89 0.91 101
weighted avg 0.93 0.92 0.92 101
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章