學習筆記(一)k-近鄰算法(KNN)

學習筆記(一)k-近鄰算法(KNN)

終於找到《機器學習實戰》這本書了,在此記錄一些總結,便於回顧。

原理

  1. KNN的工作原理是:
    存在一個樣本數據集合,也稱作訓練樣本集,並且樣本集中每個數據都存在標籤,即我們知道樣本集中每一數據與所屬分類的對應關係。輸入沒有標籤的新數據後,將新數據的每個特徵與樣本集中數據對應的特徵進行比較,然後算法提取樣本集中特徵最相似數據(最近鄰)的分類標籤。一般來說,我們只選擇樣本數據集中前k個最相似的數據,這就是k-近鄰算法中k的出處,通常k是不大於20的整數。最後,選擇k個最相似數據中出現次數最多的分類,作爲新數據的分類。
  2. 我的理解:是在訓練樣本中找出與新數據(測試樣本)距離最近的K個樣本,這K個樣本中哪個類別的樣本數最多,新數據就屬於哪一類。
  3. 距離選擇
    歐氏距離:d(x,y)=k=1n(xkyk)2d(x,y)= \sqrt{\quad \sum_{k=1}^{n}{(x_k - y_k)^2}}
    曼哈頓距離:d(x,y)=k=1nxkykd(x,y)= \quad \sum_{k=1}^{n}{|x_k - y_k|}

k-近鄰算法的一般流程

(1) 收集數據:可以使用任何方法。
(2) 準備數據:距離計算所需要的數值,最好是結構化的數據格式。
(3) 分析數據:可以使用任何方法。
(4) 訓練算法:此步驟不適用於k-近鄰算法。
(5) 測試算法:計算錯誤率。
(6) 使用算法:首先需要輸入樣本數據和結構化的輸出結果,然後運行k-近鄰算法判定輸
入數據分別屬於哪個分類,最後應用對計算出的分類執行後續的處理。

k-近鄰算法python實現

# -*- coding: utf-8 -*-
import numpy as np 
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

class KNN():  
    def fit(self,x_train,y_train):
        '''
        x_train數據格式:每一列表示一個屬性,每一行表示一個樣本
        y_train數據格式:一維數組,表示標籤,與X_train相對應
        '''
        self.x_train = x_train  
        self.y_train = y_train

    def predict(self,x_test,k = 1):
        self.k = k
        #計算歐式距離
        distance = (np.sum((self.x_train - x_test) ** 2,1)) ** 0.5
        sortindex = np.argsort(distance)
        sortindex_k = sortindex[:self.k]	#距離最近的k個樣本索引
        lable_k = self.y_train[sortindex_k]   #選擇距離最近的前k個標籤
        labelCount = {}
        for i in lable_k:
            if i in labelCount:
                labelCount[i] += 1
            else:
                labelCount[i] = 1
        result = sorted(labelCount.items(), key=lambda k:k[1], reverse=True)
        return result[0][0]


if __name__ == '__main__':

    x_train = mnist.train.images
    y_train = mnist.train.labels
    x_test = mnist.test.images
    y_test = mnist.test.labels

    knn = KNN()
    #由於訓練樣本較多,可以考慮選擇部分樣本作爲輸入
    knn.fit(x_train,y_train)
    y_predict = []
    #選擇了測試集前10個樣本做測試
    for i in range(10): 
        y_predict.append(knn.predict(x_test[i],2))

    print('預測值:',y_predict)
    print('實際結果:',y_test[:10])

小結

k-近鄰算法是分類數據最簡單最有效的算法,但是當訓練樣本數量非常大時,必定會耗費非常多的計算機資源,由於必須對數據集中的每個數據計算距離值,實際使用時可能非常耗時。

爲了緩解這些缺點,可以嘗試將原始數據進行降維,減少計算量,另外,當樣本數量比較多,而類別較少時,可以適當較少的選擇樣本進行訓練。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章