k-近鄰算法梳理(從原理到示例)(轉載)

k-近鄰算法是一個有監督的機器學習算法,k-近鄰算法也被稱爲knn算法,可以解決分類問題。也可以解決迴歸問題。本文主要內容整理爲如下:
knn算法的原理、優缺點及參數k取值對算法性能的影響;
使用knn算法處理分類問題的示例;
使用knn算法解決迴歸問題的示例;
使用knn算法進行糖尿病檢測的示例;

1 算法原理

knn算法的核心思想是未標記樣本的類別,由距離其最近的k個鄰居投票來決定。
具體的,假設我們有一個已標記好的數據集。此時有一個未標記的數據樣本,我們的任務是預測出這個數據樣本所屬的類別。knn的原理是,計算待標記樣本和數據集中每個樣本的距離,取距離最近的k個樣本。待標記的樣本所屬類別就由這k個距離最近的樣本投票產生。
假設X_test爲待標記的樣本,X_train爲已標記的數據集,算法原理的僞代碼如下:

  1. 遍歷X_train中的所有樣本,計算每個樣本與X_test的距離,並把距離保存在Distance數組中。
  2. 對Distance數組進行排序,取距離最近的k個點,記爲X_knn。
  3. 在X_knn中統計每個類別的個數,即class0在X_knn中有幾個樣本,class1在X_knn中有幾個樣本等。
  4. 待標記樣本的類別,就是在X_knn中樣本個數最多的那個類別。
    1.1 算法優缺點
    優點:準確性高,對異常值和噪聲有較高的容忍度。
    缺點:計算量較大,對內存的需求也較大。
    1.2 算法參數
    其算法參數是k,參數選擇需要根據數據來決定。
    k值越大,模型的偏差越大,對噪聲數據越不敏感,當k值很大時,可能造成欠擬合;
    k值越小,模型的方差就會越大,當k值太小,就會造成過擬合。
    1.3 變種
    knn算法有一些變種,其中之一是可以增加鄰居的權重。默認情況下,在計算距離時,都是使用相同權重。實際上,可以針對不同的鄰居指定不同的距離權重,如距離越近權重越高。這個可以通過指定算法的weights參數來實現。
    在這裏插入圖片描述

2 示例:使用k-近鄰算法進行分類

from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
# 生成數據
centers = [[-2,2], [2,2], [0,4]]
X, y = make_blobs(n_samples=60, centers=centers,
                 random_state=0, cluster_std=0.60)
# 畫出數據
plt.figure(figsize=(16,10), dpi=144)
c = np.array(centers)
# 畫出樣本
plt.scatter(X[:,0], X[:,1], c=y, s=100, cmap='cool')
# 畫出中心點
plt.scatter(c[:,0], c[:,1], s=100, marker='^',c='orange')
plt.savefig('knn_centers.png')
plt.show()

# 模型訓練
k = 5
clf = KNeighborsClassifier(n_neighbors = k)
clf.fit(X, y)

"""
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=5, p=2,
           weights='uniform')
"""
# 進行預測
X_sample = np.array([[0, 2]])
y_sample = clf.predict(X_sample)
neighbors = clf.kneighbors(X_sample, return_distance=False)


# 畫出示意圖
plt.figure(figsize=(16,10), dpi=144)
c = np.array(centers)
plt.scatter(X[:,0], X[:,1], c=y, s=100, cmap='cool') # 出樣本
plt.scatter(c[:,0], c[:,1], s=100, marker='^',c='k') # 中心點
plt.scatter(X_sample[0][0], X_sample[0][1], marker="x",
           s=100, cmap='cool')      # 待預測的點
for i in neighbors[0]:
    plt.plot([X[i][0], X_sample[0][0]], [X[i][1], X_sample[0][1]],
            'k--', linewidth=0.6)  # 預測點與距離最近的5個樣本的連線
plt.savefig('knn_predict.png')
plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章