機器學習輕鬆入門——KNN算法的PYTHON語言實現

KNN算法,也稱K近鄰算法,是一種監督學習的分類算法。

本篇文章主要由以下幾個方面構成:

  1. KNN算法的原理及僞代碼
  2. KNN算法的優缺點
  3. KNN算法實現手寫數字識別系統

1.KNN算法的原理及僞代碼

KNN算法,即在已知訓練集數據所對應標籤的情況下,去預測測試集數據所對應的標籤,其算法核心就是要找到其訓練集數據與其標籤之間的對應關係。

僞代碼:

  • 計算當前測試數據(1個)與所有訓練集數據(N個)的距離
  • 將有距離升序排序,篩選出與當前測試數據距離最短的前K個訓練集數據(算法難點:K值如何選定,一般<20
  • 依次找到對應的標籤,統計標籤的出現頻次
  • 返回出現頻次最高的標籤作爲當前測試數據的預測分類結果

2.KNN算法的優缺點

優點:簡單易懂,精度高,對異常值不敏感

缺點:計算複雜度高,空間複雜度高

3.KNN算法實現手寫數字識別系統

3.1該項目的實現過程及架構

項目概述:

 

該項目主要由三大塊組成:

  • 前期數據預處理:將32行*32列的圖像數據轉換成1行*1024列的向量數據
  • 構建KNN分類器:實現對當前測試集數據(單個)的分類結果預測
  • (主程序)測試KNN分類器:對所有測試集數據進行分類,並統計其錯誤率

3.2該項目的PYTHON代碼

用到的Python模塊:numpy、os模塊中listdir函數、operator模塊

首先導入模塊:

import numpy as np
import operator
from os import listdir #從os模塊中導入listdir函數,實現讀取文件夾下的所有文件名功能

然後進行數據預處理:

#將32*32轉換成1*1024
def img2vector(filename):
    # 創建向量
    returnVect = np.zeros((1, 1024))
    # 打開數據文件,讀取每行內容
    fr = open(filename)
    for i in range(32):
        # 讀取每一行
        lineStr = fr.readline()
        # 將每行前 32 字符轉成 int 存入向量
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect

接着構建KNN分類器:

#KNN分類器
import operator
def classify0(inX, dataSet, labels, k):

    """
    參數: 
    - inX: 需要預測分類的當前測試集數據
    - dataSet: 輸入的訓練集數據
    - labels: 訓練集數據的標籤向量
    - k: 用於選擇最近鄰居的數目
    """

    # 獲取訓練數據集的行數
    dataSetSize = dataSet.shape[0]

    # 矩陣運算,計算測試數據與每個樣本數據對應數據項的差值
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet

    # sqDistances 上一步驟結果平方和
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)

    # 取平方根,得到距離向量
    distances = sqDistances**0.5

    # 按照距離從低到高排序
    sortedDistIndicies = distances.argsort()
    
    # 依次取出最近的樣本數據
    for i in range(k):
        # 根據索引,找到該樣本數據所屬的標籤
        voteIlabel = labels[sortedDistIndicies[i]]
        # 建立一個字典,用於存放標籤出現的頻次,統計標籤出現的頻次
        classCount = {}
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # 對標籤出現的頻次進行排序,從高到低
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True)

    # 返回出現頻次最高的類別
    return sortedClassCount[0][0]

注意:

  • shape函數是numpy.core.fromnumeric中的函數,它的功能是查看矩陣或者數組的維數

    舉例說明:建立一個2×4的單位矩陣e, e.shape爲(2,4),表示2行4列,第一維的長度爲2,第二維的長度爲4。e.shape[0]爲2,e.shape[1]爲4。

  • tile函數的功能是對一個矩陣進行重複,得到新的矩陣,例如:b=np.tile(a,(3,1)),得到b=[a,a,a](將a重複3行)

  • argsort函數是返回一個數組升序排列的索引值,例如:x = np.array([3, 1, 2]),np.argsort(x),返回array([1, 2, 0])

  • operator.itemgetter()返回的是一個函數

       operator.itemgetter(1)按照第二個元素的次序對元組進行排序,reverse=True是逆序,即按照從大到小的順序排列

       所以 sorted這裏的意思是:

       classCount.items()將classCount字典分解爲元組列表

       即由變成

       並且按第二個元素進行從大到小的排列

       最後return sortedClassCount[0][0],就是返回sortedClassCount的第一行第一列,即頻數最高的那個對應的分類標籤

 

最後是該項目的主程序,對所有測試集數據進行分類預測:

def handwritingClassTest():
    # 樣本數據的類標籤列表
    hwLabels = []

    # 樣本數據文件列表
    trainingFileList = listdir(r'C:\Users\lenovo\Desktop\PYTHON\識別手寫數字\trainingDigits')
    m = len(trainingFileList)
    # 初始化樣本數據矩陣(M*1024)
    trainingMat = np.zeros((m, 1024))

    # 依次讀取所有樣本數據到數據矩陣
    # fileNamestr裏存放的是當前訓練集文件名,提取文件名中的數字
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        # 將樣本數據存入矩陣
        trainingMat[i, :] = img2vector(r'C:\Users\lenovo\Desktop\PYTHON\識別手寫數字\trainingDigits/%s' % fileNameStr)


    # 循環讀取測試數據
    testFileList = listdir(r'C:\Users\lenovo\Desktop\PYTHON\識別手寫數字\testDigits')
    mTest = len(testFileList)

    # 初始化錯誤率
    errorCount = 0.0
    
    # 循環測試每個測試數據文件
    for i in range(mTest):
        # fileNamestr裏存放的是當前測試集文件名,提取文件名中的數字
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])

        # 提取數據向量
        vectorUnderTest = img2vector(r'C:\Users\lenovo\Desktop\PYTHON\識別手寫數字\testDigits/%s' % fileNameStr)

        # 對數據文件進行分類,K值選取爲3
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

        # 打印 K 近鄰算法分類結果和真實的分類
        print("測試樣本 %d, 分類器預測: %d, 真實類別: %d" %
              (i+1, classifierResult, classNumStr))

        # 判斷K 近鄰算法結果是否準確
        if (classifierResult != classNumStr):
            errorCount += 1.0

    # 打印錯誤率
    print("\n錯誤分類計數: %d" % errorCount)
    print("\n錯誤分類比例: %f" % (errorCount/float(mTest)))

注意:

  • split函數爲根據分隔符對字符串進行切片,split(‘分隔符’,num),num默認爲全分割,num=1則只分割一次。例如:a='1_2.txt',b=a.split('.')=['1_2','txt']
  • append() 方法用於在列表末尾添加新的對象

 

最後的最後不要忘了運行主程序哦:

handwritingClassTest()

修改K值,分類結果的準確率不同:

K=2,錯誤計數13

K=3,錯誤計數10

K=4,錯誤計數11

K=5,錯誤計數17

K=6,錯誤計數17

最佳K值爲3

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章