# 機器學習實戰第二章KNN算法源碼 (Machine Learning in Action, Chapter 2: kNN algorithm source code)

from numpy import *

import operator

from os import listdir

import matplotlib.pyplot as plt


"""程序清單2-1 K近鄰算法"""

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest training rows.

    Distance is Euclidean: the query is subtracted from every row of
    ``dataSet``, the differences are squared, summed per row and
    square-rooted.

    Args:
        inX: Query vector to classify.
        dataSet: Array with one training sample per row (must expose
            ``.shape``).
        labels: Sequence indexable by row number, giving each row's label.
        k: Number of nearest neighbours that vote.

    Returns:
        The label with the most votes. On a tie, the label that was seen
        first while walking the neighbours from nearest to farthest wins
        (the sort is stable and the vote dict keeps insertion order).

    Raises:
        IndexError: If ``k`` is larger than the number of rows, or if no
            votes were collected (``k <= 0``).
    """
    rowCount = dataSet.shape[0]

    # Repeat the query once per training row so the subtraction is row-wise.
    deltas = tile(inX, (rowCount, 1)) - dataSet
    distances = ((deltas ** 2).sum(axis=1)) ** 0.5
    nearestFirst = distances.argsort()

    votes = {}
    for rank in range(k):
        label = labels[nearestFirst[rank]]
        if label in votes:
            votes[label] += 1
        else:
            votes[label] = 1

    # Stable descending sort on the vote count; the winner ends up in front.
    ranked = list(votes.items())
    ranked.sort(key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]


"""程序清單2-6 手寫數字識別系統的測試代碼"""

def handwritingClassTest():
    """Evaluate the kNN classifier on the handwritten-digit text images.

    Every file in the ``trainingDigits`` directory is loaded with
    ``img2vector`` into one row of a training matrix. Every file in
    ``testDigits`` is then classified with ``classify0`` using ``k=3``.
    The true label is taken from the file name: the integer before the
    first ``_`` of the part preceding the first ``.``.

    Prints one line per test file, followed by the total number of errors
    and the error rate. Returns nothing.
    """
    # Build the training matrix and its parallel list of labels.
    trainNames = listdir('trainingDigits')
    trainCount = len(trainNames)
    trainingMat = zeros((trainCount, 1024))
    hwLabels = []
    for row, name in enumerate(trainNames):
        stem = name.split('.')[0]  # drop the extension
        hwLabels.append(int(stem.split('_')[0]))
        trainingMat[row, :] = img2vector('trainingDigits/%s' % name)

    # Classify each test image and count the mismatches.
    testNames = listdir('testDigits')
    mistakes = 0.0
    testCount = len(testNames)
    for name in testNames:
        stem = name.split('.')[0]  # drop the extension
        expected = int(stem.split('_')[0])
        sample = img2vector('testDigits/%s' % name)
        predicted = classify0(sample, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (predicted, expected))
        if predicted != expected:
            mistakes += 1.0

    print("\nthe total number of errors is: %d" % mistakes)
    print("\nthe total error rate is: %f" % (mistakes/float(testCount)))


def img2vector(filename):
    """Load a 32x32 text image of digit characters into a flat row vector.

    The first 32 characters of each of the first 32 lines are converted
    with ``int()`` and stored row by row.

    Args:
        filename: Path of the text file to read.

    Returns:
        A NumPy array of shape ``(1, 1024)``; element ``32*i + j`` holds the
        value of character ``j`` on line ``i``.

    Raises:
        IndexError: If a line has fewer than 32 characters (including when
            the file has fewer than 32 lines).
        ValueError: If a character cannot be converted by ``int()``.
    """
    returnVect = zeros((1, 1024))

    # The context manager closes the handle even if parsing fails; this
    # function runs once per image, so a leaked handle adds up quickly.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32*i + j] = int(lineStr[j])

    return returnVect


if __name__ == "__main__":

    # Smoke check: load one sample image and print the start of its first
    # two rows.
    # NOTE(review): img2vector stores 32 values per row, so these slices
    # show 31 of the 32 pixels in each row - confirm whether 0:32 and 32:64
    # were intended.
    sampleVector = img2vector("testDigits/0_13.txt")

    for start, stop in ((0, 31), (32, 63)):
        print(sampleVector[0, start:stop])

    # The full evaluation is left disabled; call handwritingClassTest() to run it.


# --- Web-page residue from the scraped article (not part of the program) ---
# 發表評論
# 所有評論
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
# 相關文章