K-nn手寫數字識別--Python版

模式識別的實驗作業,弄了一個晚上終於在第二天中午弄明白了!

簡單來說,k-nn就是通過計算訓練集和 一個測試數據之間的歐式距離,然後將計算結果按照從小到大來排序,找出最小的k個數據,分析k個數據中哪種情況出現的頻率最多,那麼這個測試數據就屬於這一類

  1. 思路

  2. 讀入數據,假設100個訓練數據,將訓練數據轉換爲100*1024的二維數組,然後循環讀入測試數據,計算測試數據和100個訓練數據間的歐式距離:
    歐式距離公式
    x1-xn爲單個訓練數據的所有元素,y1-yn爲測試數據的所有元素

    這樣就得到一個數組,包含所有訓練數據和測試數據的歐式距離,將距離從小到大進行排序。

3. 結果

找出k個最近的距離,看哪個數字出現的頻率最多,那麼這個測試數據大概率爲這個數字

#解壓文件
def JY():
    path="/Users/fanjialiang2401/PycharmProjects/模式識別/digits.zip"
    newpath="/Users/fanjialiang2401/PycharmProjects/模式識別/"
    f=zipfile.ZipFile(path,'r')
    for  file in f.namelist():
        f.extract(file,newpath)
    print("success!")
#     將32*32矩陣轉換爲一個長爲1024的一位數字
def toVerctor(filename):
    returnVect=np.zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        linestr=fr.readline();
        for j in range(32):
            returnVect[0,32*i+j]=int(linestr[j])
    return returnVect;
# 測試 trainlist爲訓練集所有數據,testdata爲測試數據 classLable爲
def Classfiy(Trainlist,testdata,classLable,k):
    listSize=len(Trainlist)
    diffs=[]
    for i in range(listSize):
        traindata=Trainlist[i];
        diffvalue=np.sum(np.square(traindata-testdata))
        diff=np.sqrt(diffvalue)
        diffs.append(diff)
    sortIndex=np.argsort(diffs)
    #sortIndex  argsort對所有元素進行排序,返回的是序號值
    num=[]
    for i in range(10):
        num.append(0)
    for i in range(k):
        num[int(classLable[sortIndex[i]])]+=1;
#    找出出現頻率最多的數
    s=np.argsort(num)
    return s[9]

#讀取並且處理文件 相當於main方法 在這裏調用其他方法
def Read():
    hwlable=[]
    # 將讀入的數據32*32轉換爲1024*length的數組
    Trainlist=os.listdir('trainingDigits')
    length=len(Trainlist)
    trainMat=np.zeros((length,1024))

    for i in range (length):
        # 讀取文件名
        filename=Trainlist[i]
        filestr=filename.split(".")[0]
        #通過字符串分割,得到數字
        classNum=filestr.split('_')[0]
        hwlable.append(classNum)
        trainMat[i:]=toVerctor('trainingDigits/%s'%filename)
    # 測試集
    # 測試文件 循環比較
    testFileList=os.listdir('testDigits')
    errorCount=0;
    TestLength=len(testFileList)
    for i in range(TestLength):
        filenamestr=testFileList[i]
        filestr=filenamestr.split(".")[0]
        classStr=filestr.split("_")[0]
        # 測試向量
        testVector=toVerctor('testDigits/%s'%filenamestr)
        lable=Classfiy(trainMat,testVector,hwlable,5)
        if lable!=int(classStr):
            errorCount+=1
            print('false'+str(lable)+":"+classStr)
    print("正確個數:"+str(TestLength-errorCount))
    print("正確率:"+str((TestLength-errorCount)/TestLength))

結果:
這裏寫圖片描述
看的出正確率還挺高的

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章