kNN算法識別手寫數字(代碼筆記)

k-近鄰算法,屬於有監督分類算法。

思想:利用輸入數據特徵值和訓練樣本數據特徵值之間的距離分類,挑出距離最小的k個訓練樣本的類別頻率,作爲預測的分類估計。

'''
k-近鄰算法是基於實例的學習
1 使用時要保存全部的數據集,佔存儲空間
2 要對每個訓練數據計算距離值,實際使用時非常耗時
'''
import numpy as np
import operator

def classify0(x, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(x, (dataSetSize,1)) - dataSet
    sqDiff = diffMat**2
    sqDist = sqDiff.sum(axis=1)
    distances = sqDist**0.5  # 一行數據的平方根
    sortedDistInd = distances.argsort()  # 向量元素從小到大對應的索引號
    classCount = {}
    for i in range(k):  # 前k個,也就是最近的k個; 統計類出現的頻率
        vLabel = labels[sortedDistInd[i]]  
        classCount[vLabel] = classCount.get(vLabel,0)+1
    sortedClassCount = sorted(classCount.items(), # 轉成dict_items:[(key1,cnt1),(key2,cnt2),..]
                       key=operator.itemgetter(1), # 排序,依據tuple第二個元素;reverse,由大到小
                       reverse=True)
    return sortedClassCount[0][0]
    
def img2vec(filename):  # 32x32的矩陣數據轉成向量
    vec = np.zeros((1,1024))
    fr = open(filename)  # (如果是txt文件的話)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            vec[0, 32*i+j] = int(lineStr[j])
    return vec

def handwritingClassify():
    trainLab = []
    trainFileList = listdir('trainingDigits')  # 訓練數據目錄
    m = len(trainFileList)
    trainMat = zeros((m,1024))  # 訓練數據存成一個矩陣
    for i in range(m):
        filenameStr = trainFileList[i]
        fileStr = filenameStr.split('.')[0]
        classStr = int(fileStr.split('_')[0])
        trainLab.append(classStr)
        trainMat[i,:] = img2vec('trainingDigits/%s' % filenameStr)
    #------------------------ 測試數據 -------------------------
    errorCount = 0.0
    testFileList = listdir('testDigits')  # 測試數據目錄
    n = len(testFileList)
    for i in range(n):
        filenameStr = testFileList[i]
        fileStr = filenameStr.split('.')[0]
        classStr = int(fileStr.split('_')[0])
        vecTest = img2vec('testDigits/%s' % filenameStr)
        classTest = classify0(vecTest, trainMat, trainLab, 3)  # 測試數據的直接分類
        print("the classifier predicts : %d, the real is : %d" % (classTest,classStr))
        if(classTest!=classStr):
            error += 1.0
    print("\n the total numbers of errors is: %d" % errorCount)
    print("\n the total error rate is: %d" % (error/float(n)))


發佈了71 篇原創文章 · 獲贊 98 · 訪問量 42萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章