KNN算法 數字識別

import os
import numpy as np


def data_trans(dir_path):
    """Convert a directory of 32x32 text digit images into a single CSV file.

    Every entry in ``dir_path`` is read as 32 lines of 32 digit characters
    and flattened into a row of 1024 values. A 1025th column holds the
    label, taken from the first character of the file name. The resulting
    table (one row per file) is written with ``np.savetxt`` to
    ``<directory name>.csv`` in the current working directory.

    Args:
        dir_path: Path of the directory that contains the digit text files.
            A trailing path separator is tolerated.
    """
    side = 32
    n_pixels = side * side

    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order of the generated CSV reproducible between runs and machines.
    file_list = sorted(os.listdir(dir_path))

    # One row per file: 1024 pixel values followed by the label.
    big_arr = np.zeros((len(file_list), n_pixels + 1), dtype=int)
    for i, file_name in enumerate(file_list):
        file_path = os.path.join(dir_path, file_name)
        # Each line of the file is loaded as one string of digit characters.
        rows = np.loadtxt(file_path, dtype=str)
        image = np.zeros((side, side), dtype=int)
        for j, row in enumerate(rows):
            image[j] = [int(ch) for ch in row]
        # Flatten the 32x32 image into the first 1024 columns.
        big_arr[i, :n_pixels] = image.ravel()
        # The last column stores the label (first character of the name).
        big_arr[i, n_pixels] = int(file_name[0])

    # normpath drops a trailing separator so the directory name is never
    # empty (a plain split('/')[-1] would yield "" for "path/to/dir/").
    name = os.path.basename(os.path.normpath(dir_path))
    np.savetxt("{}.csv".format(name), big_arr, fmt='%d')


if __name__ == '__main__':
    # Build one CSV per data set: the training digits first, then the
    # test digits.
    for digits_dir in ("./digits/trainingDigits", "./digits/testDigits"):
        data_trans(digits_dir)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def knn(train_digit, test_digit, k):
    """Classify each test row by k-nearest-neighbour vote and return accuracy.

    Both arrays are 2-D with the label in the last column and the features
    in all preceding columns. For every test row the Euclidean distance to
    every training row is computed, the ``k`` closest training rows vote,
    and the most frequent label wins (the smallest label on a tie). The
    predicted and true label of every row are printed, followed by the
    overall accuracy.

    Args:
        train_digit: Training samples, features plus label column.
        test_digit: Test samples, same column layout as ``train_digit``.
        k: Number of nearest neighbours that take part in the vote.

    Returns:
        The fraction of test rows whose predicted label equals the true
        label, between 0.0 and 1.0.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``test_digit`` has no rows.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    n_test = test_digit.shape[0]
    if n_test == 0:
        raise ValueError("test_digit must contain at least one row")

    train_features = train_digit[:, :-1]
    train_labels = train_digit[:, -1]

    # Start from zero: counting from one would inflate every accuracy by
    # 1/n_test and could report a value above 1.0.
    true_num = 0
    for i in range(n_test):
        # Euclidean distance from this test row to every training row.
        d = np.sqrt(((test_digit[i, :-1] - train_features) ** 2).sum(axis=1))
        nearest = d.argsort()[:k]
        # Majority vote; np.unique returns sorted labels and argmax picks the
        # first maximum, so ties resolve to the smallest label.
        labels, counts = np.unique(train_labels[nearest], return_counts=True)
        predicted = labels[counts.argmax()]
        print("預測值: ", predicted)
        print("真實值: ", test_digit[i, -1])
        if predicted == test_digit[i, -1]:
            true_num += 1

    accuracy = true_num / n_test
    print('準確度爲:', accuracy)
    return accuracy


if __name__ == '__main__':
    # data_trans names each CSV after its source directory, so the files are
    # "trainingDigits.csv" and "testDigits.csv". The names must match in case
    # or loading fails on case-sensitive filesystems.
    train_digit = np.loadtxt("./trainingDigits.csv")
    test_digit = np.loadtxt("./testDigits.csv")

    # Evaluate the classifier for k = 5 .. 14 and collect each accuracy.
    x = range(5, 15)
    y = [knn(train_digit, test_digit, k) for k in x]

    print(x)
    print(y)

    # Plot accuracy against k.
    plt.figure()
    plt.plot(x, y, marker='*', markersize=12)
    plt.xlabel('k')
    plt.ylabel('precision')
    plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章