K-nn手写数字识别--Python版

模式识别的实验作业,弄了一个晚上终于在第二天中午弄明白了!

简单来说,k-nn就是通过计算训练集和 一个测试数据之间的欧式距离,然后将计算结果按照从小到大来排序,找出最小的k个数据,分析k个数据中哪种情况出现的频率最多,那么这个测试数据就属于这一类

  1. 思路

  2. 读入数据,假设100个训练数据,将训练数据转换为100*1024的二维数组,然后循环读入测试数据,计算测试数据和100个训练数据间的欧式距离:
    欧式距离公式
    x1-xn为单个训练数据的所有元素,y1-yn为测试数据的所有元素

    这样就得到一个数组,包含所有训练数据和测试数据的欧式距离,将距离从小到大进行排序。

3. 结果

找出k个最近的距离,看哪个数字出现的频率最多,那么这个测试数据大概率为这个数字

#解压文件
def JY():
    path="/Users/fanjialiang2401/PycharmProjects/模式识别/digits.zip"
    newpath="/Users/fanjialiang2401/PycharmProjects/模式识别/"
    f=zipfile.ZipFile(path,'r')
    for  file in f.namelist():
        f.extract(file,newpath)
    print("success!")
#     将32*32矩阵转换为一个长为1024的一位数字
def toVerctor(filename):
    returnVect=np.zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        linestr=fr.readline();
        for j in range(32):
            returnVect[0,32*i+j]=int(linestr[j])
    return returnVect;
# 测试 trainlist为训练集所有数据,testdata为测试数据 classLable为
def Classfiy(Trainlist,testdata,classLable,k):
    listSize=len(Trainlist)
    diffs=[]
    for i in range(listSize):
        traindata=Trainlist[i];
        diffvalue=np.sum(np.square(traindata-testdata))
        diff=np.sqrt(diffvalue)
        diffs.append(diff)
    sortIndex=np.argsort(diffs)
    #sortIndex  argsort对所有元素进行排序,返回的是序号值
    num=[]
    for i in range(10):
        num.append(0)
    for i in range(k):
        num[int(classLable[sortIndex[i]])]+=1;
#    找出出现频率最多的数
    s=np.argsort(num)
    return s[9]

#读取并且处理文件 相当于main方法 在这里调用其他方法
def Read():
    hwlable=[]
    # 将读入的数据32*32转换为1024*length的数组
    Trainlist=os.listdir('trainingDigits')
    length=len(Trainlist)
    trainMat=np.zeros((length,1024))

    for i in range (length):
        # 读取文件名
        filename=Trainlist[i]
        filestr=filename.split(".")[0]
        #通过字符串分割,得到数字
        classNum=filestr.split('_')[0]
        hwlable.append(classNum)
        trainMat[i:]=toVerctor('trainingDigits/%s'%filename)
    # 测试集
    # 测试文件 循环比较
    testFileList=os.listdir('testDigits')
    errorCount=0;
    TestLength=len(testFileList)
    for i in range(TestLength):
        filenamestr=testFileList[i]
        filestr=filenamestr.split(".")[0]
        classStr=filestr.split("_")[0]
        # 测试向量
        testVector=toVerctor('testDigits/%s'%filenamestr)
        lable=Classfiy(trainMat,testVector,hwlable,5)
        if lable!=int(classStr):
            errorCount+=1
            print('false'+str(lable)+":"+classStr)
    print("正确个数:"+str(TestLength-errorCount))
    print("正确率:"+str((TestLength-errorCount)/TestLength))

结果:
这里写图片描述
看的出正确率还挺高的

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章