K 近鄰法及其在手寫數字識別的實踐

文章首發於 個人博客

引言

k 近鄰法(k-nearest-neighbor, KNN)是一種基本的分類和迴歸方法。現在只討論其分類方面的應用,它不具備明顯的學習過程,實際上是利用已知的訓練數據集對輸入特徵向量空間進行劃分,並作爲其分類的“模型”。

其中 k 值的選擇、距離的度量及分類決策規則是 k 近鄰模型的三個基本要素。

本文將按照以下提綱進行:

  • k 近鄰法闡述
  • k 近鄰的模型
  • k 近鄰在手寫數字識別上的實戰

k 近鄰法闡述

k 近鄰算法非常容易理解,因爲其本質上就是求距離,這是非常簡單而直觀的度量方法:對於給定的一個訓練數據集,對新的輸入實例 M,在訓練數據集中找到與該新實例 M 最鄰近的 k 個實例,由這 k 個實例按照一定的表決規則進行投票決策最合適的類別,那麼實例 M 就屬於這個類。下面是算法的描述:
CC3A072A-F73B-4A54-A6F9-94EF80316B66

k 近鄰模型

k 近鄰算法本質上是在超空間內劃分區域空間分類的問題,在輸入數據集的特徵空間內,對於每個訓練實例點 xix_i ,距離改點比其他點更近的所有點組成一個區域,叫做單元(cell)。上文說了 k 近鄰模型的三個要素,k 值選擇、距離度量、決策函數,下面一一說明。

k 值選擇

k 值指的是選擇近鄰點的數目,如果 k = 1 則是最近鄰,即是每次由距離新實例最近的訓練點所屬的類別決定待分類實例的類別。

k 值的選擇對於 k 近鄰法的結果可以產生重大影響。

當 k 值較小的時候,那麼預測學習的近似誤差會減少,因爲此時只有距離待分類點較近的訓練實例纔會對於分類預測結果有影響作用,但是缺點是估計誤差會增大,因爲預測結果會對近鄰的實例點非常敏感,如果近鄰的實例多數都是噪聲點,那麼就很容易導致預測出錯。即是說,k 值的減少意味着模型變得複雜,容易發生過擬合。

當 k 值較大的時候,就相當於用較大鄰域中的訓練實例進行預測。其優點是可以減少學習的估計誤差。但缺點是學習的近似誤差會增大,這時與輸入實例較遠的(不相似的)訓練實例也會對預測起作用,使預測發生錯誤。k 值的增大就意味着整體的模型變得簡單。

如果 k = N,那麼無論輸入實例是什麼,都將簡單地預測它屬於在訓練實例中最多的類。這時,模型過於簡單,完全忽略訓練實例中的大量有用信息,是不可取的。

在應用中,k 值一般取一個比較小的數值。通常採用交叉驗證法來選取最優的 k 值。

距離度量

在實數域中,數的大小和兩個數之間的距離是通過絕對值來度量的。在解析幾何中,向量的大小和兩個向量之差的大小是“長度”和“距離”的概念來度量的。爲了對矩陣運算進行數值分析,我們需要對向量 和矩陣的“大小”引進某種度量。而範數是絕對值概念的自然推廣。

特徵空間中兩個實例點的距離是其相似程度的反映,k 近鄰空間選用歐式距離及更一般的 LpL_p 距離。

設特徵空間 X 是 n 維實數向量空間 RnR^nxi,xjX,xi=(xi1,xi2,,xin),xj=(xj1,xj2,,xjn)x_i,x_j \in \mathcal{X}, x_i = (x_i^1,x_i^2,\cdots,x_i^n), x_j = (x_j^1,x_j^2,\cdots,x_j^n),則 xi,xjx_i, x_jLpL_p距離定義爲:
Lp(xi,xj)=(l=1nxilxjlp)1p L_p(x_i,x_j) = \left( \sum_{l=1}^n |x_i^l-x_j^l|^p \right) ^{\frac{1}{p}}

這裏 p 要不小於1,當 p = 2時,成爲歐氏距離;
當 p = 1 時,稱爲曼哈頓距離;
當 p = \infty 時,它是各個座標距離的最大值。

分類決策

k 近鄰法中的分類決策規則往往是多數投票表決,即由輸入實例的 k 個鄰近的訓練實例中的多數類決定輸入實例的類。

多數表決規則(majorityvotingrule)有如下解釋:如果分類的損失函數爲 0-1 損失函數,分類函數爲:
f:Rn{c1,c2,,cK} f:R^n \to \{c_1,c_2,\cdots, c_K\}

那麼對給定的實例 xXx\in X,其最近鄰的 k 個訓練實例點構成集合 Nk(xN_k(x。如果涵蓋Nk(x)N_k(x)的區域的類別是,那麼誤分類率是:
1kxiNk(x)I(yicj)=11kxiNk(x)I(yi=cj) \frac{1}{k} \sum_{x_i\in N_k(x)} I (y_i\not=c_j) = 1 - \frac{1}{k} \sum_{x_i\in N_k(x)} I (y_i = c_j)

要使誤分類率最小即經驗風險最小,就要使xiNk(x)I(yi=cj)\sum_{x_i\in N_k(x)} I (y_i = c_j)最大,所以多數表決規則等價於經驗風險最小化。

同時多數表決可以加權表決,可以一定程度提高表決結果的準確性。

k 近鄰在手寫數字識別上的實戰

數據集的讀取和解析和樸素貝葉斯法識別手寫數字的原理一樣,這裏不再贅述。

代碼實現算法上,這裏先採用線性暴搜的方法,效率上明顯是非常低的,耗時也比樸素貝葉斯慢的多,但是準確率卻非常高,目前表決數爲 k=3 的情況下且不加權的預測準確率可以達到 94% 以上。

訓練預測結果如下:
597C98CC-BFA6-4998-9E03-A43C6BAF6620

可以看出,測試 2100 個圖片,用了1218秒,20多分鐘,效率非常慢。
A535D08C-349E-46B2-8147-E59F88076FB1

但是準確率異常高,且比較穩定。

總結

  • 更高效率的 k 近鄰尋找方法是 k-d樹(k-dimensional樹的簡稱),這是一種分割 k 維數據空間的數據結構,主要應用於多維空間關鍵數據的搜索。

  • 可以對 k 近鄰進行加權表決,對於預測準確率應該也會有所提升。

NEXT

下一次將實踐以上的兩點總結,看看具體的表現如何吧。


附KNN 算法的線性暴搜實現如下:

# -*- coding: utf-8 -*
import time
import matplotlib.pyplot as plt
import testLibrary as tl
import collections
import numpy as np

# 距離計算
def calc_dis(train_image,test_image):
    dist=np.linalg.norm(train_image-test_image)
    return dist


# 確定待分類實例的 k 近鄰
def find_labels(k,train_images,train_labels,test_image):
    all_dis = []
    labels=collections.defaultdict(int)
    for i in range(len(train_images)):
        dis = np.linalg.norm(train_images[i]-test_image)
        all_dis.append(dis)
    sorted_dis = np.argsort(all_dis)
    count = 0
    while count < k:
        labels[train_labels[sorted_dis[count]]]+=1
        count += 1
    return labels


# 結合訓練數據集,對所有待分類實例進行 k 近鄰分類預測
def knn_all(k,train_images,train_labels,test_images, test_labels):
    print("start knn_all!")
    res=[]
    right = 0
    accuracy = []
    count=0
    for i in range(2100):
        labels=find_labels(k,train_images,train_labels,test_images[i])
        res.append(max(labels))
        print("Picture %d has been predicted! real is %d predicted is %d"%(count, test_labels[i], max(labels)))
        count+=1
        if max(labels) == test_labels[i]:
            right+=1
        if (i+1) % 70 == 0:
            accuracy.append(float(right)/(i+1))
    return res, accuracy


# 總的預測準確率計算
def calc_precision(res,test_labels):
    f_res_open=open("res.txt","a+")
    precision=0
    for i in range(len(res)):
        f_res_open.write("res:"+str(res[i])+"\n")
        f_res_open.write("test:"+str(test_labels[i])+"\n")
        if res[i]==test_labels[i]:
            precision+=1
    return precision/len(res)


if __name__ == '__main__':
    print('Start process train data')
    time_0 = time.time()
    # tl.get_train_set()

    print('Start process test data')
    time_t = time.time()
    # tl.get_test_set()

    # 讀取訓練數據集和測試數據集的方法和樸素貝葉斯方法一致
    print ('Start read train data')
    time_1 = time.time()
    data_map, labels = tl.loadCSVfile("data.csv")
    print(data_map.shape, labels.shape)
    time_2 = time.time()
    print('read data train cost ', time_2 - time_1, ' seconds', '\n')

    print('Start read predict data')
    time_3 = time.time()
    test_data_map, test_labels = tl.loadCSVfile("dataTest.csv")
    print(test_data_map.shape, test_data_map.shape)
    time_4 = time.time()
    print('read predict data cost ', time_4 - time_3, ' seconds', '\n')

    print('Start predicting data')
    time_5 = time.time()
    res, accuracy = knn_all(3, data_map, labels, test_data_map, test_labels)
    score = calc_precision(res, test_labels)
    time_6 = time.time()
    print('read predict data cost ', time_6 - time_5, ' seconds', '\n')

    new_ticks = np.linspace(1, 30, 30)
    plt.xticks(new_ticks)
    plt.ylim(ymin=0.5, ymax = 1)
    plt.plot(new_ticks, accuracy, 'o-', color='g')
    plt.xlabel("x -- 1:70")
    plt.ylabel("y")
    plt.title(u"預測準確率")
    plt.show()

    print("The accuracy rate is ", score)
    print("All data processing cost %s seconds" % (time_6 - time_0))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章