【機器學習】k近鄰(KNN)算法

在這裏插入圖片描述

import numpy as np
import time
# from perception import data_load

def data_load(filename):
    '''
    Load a comma-separated file of integer rows into features and binary labels.

    Every non-blank line is read as ``label,value,value,...`` where all
    fields are integers. Blank lines are skipped.

    :param filename: path of the CSV file to read
    :return: dataArr,labelArr -- dataArr is a list with one list per row
             holding every field after the first divided by 255; labelArr
             holds 1 when the first field is >= 5 and -1 otherwise
    '''
    print('start read file')
    dataArr,labelArr = [],[]
    with open(filename,'r') as f:
        # Iterate the file object directly so the raw lines are not all
        # held in memory next to the parsed rows.
        for line in f:
            line = line.strip()
            if not line:
                # A blank line has no label field; int('') would raise
                # ValueError, so skip it.
                continue
            fields = line.split(',')

            # Collapse the first field into two classes: >= 5 -> 1, else -1.
            labelArr.append(1 if int(fields[0]) >= 5 else -1)

            # Scale the remaining integer fields by 1/255.
            dataArr.append([int(num)/255 for num in fields[1:]])
    print('End')
    return dataArr,labelArr

def cal_distance(x1,x2):
    '''
    Compute the Euclidean distance between two points.

    :param x1: first point (array-like of coordinates)
    :param x2: second point (array-like of coordinates)
    :return: Euclidean (L2) distance between x1 and x2
    '''
    # Element-wise difference of the two points after converting to arrays.
    delta = np.array(x1) - np.array(x2)
    # Square root of the sum of squared differences.
    return np.sqrt(np.sum(delta * delta))

def knn(traindata,trainlabel,x,k):
    '''
    Predict the label of x by majority vote among its k nearest training points.

    Distances are Euclidean and computed for all training samples in one
    vectorized step. If k exceeds the number of training samples, every
    sample votes. When several labels tie for the most votes, the label
    that appears first among the neighbours ordered by increasing
    distance wins.

    :param traindata: training samples, one per entry (nested list or
                      ndarray); passing a float ndarray avoids converting
                      the data again on every call
    :param trainlabel: training labels, trainlabel[i] belongs to traindata[i]
    :param x: target point
    :param k: the k of k-nearest-neighbours, must be >= 1
    :return: predicted label of the target point
    :raises ValueError: if k is smaller than 1 or traindata is empty
    '''
    if k < 1:
        raise ValueError('k must be at least 1, got {}'.format(k))
    if len(traindata) == 0:
        raise ValueError('traindata must not be empty')

    train = np.asarray(traindata, dtype=float)
    point = np.asarray(x, dtype=float)

    # Flatten each sample to one row so the subtraction broadcasts the
    # query against every training sample at once.
    diff = train.reshape(len(train), -1) - point.reshape(1, -1)
    dis = np.sqrt(np.sum(np.square(diff), axis=1))

    # A stable sort keeps equally distant samples in training order, so
    # the selected neighbours are deterministic.
    nearest = np.argsort(dis, kind='stable')[:k]

    # Tally votes in nearest-first order; dicts preserve insertion order
    # and max() returns the first maximal key, which gives the documented
    # tie-break.
    votes = {}
    for idx in nearest:
        label = trainlabel[idx]
        votes[label] = votes.get(label, 0) + 1
    return max(votes, key=votes.get)

def model_test(traindata,trainlabel,testdata,testlabel,k,max_samples=200):
    '''
    Measure knn accuracy on the leading samples of the test set.

    :param traindata: training samples passed through to knn
    :param trainlabel: training labels passed through to knn
    :param testdata: test samples
    :param testlabel: test labels, testlabel[i] belongs to testdata[i]
    :param k: the k of k-nearest-neighbours
    :param max_samples: upper bound on how many test samples are evaluated;
                        None evaluates the whole test set
    :return: accuracy in percent over the evaluated samples
    :raises ValueError: if there is no test sample to evaluate
    '''
    # Never index past the end of the test set when it is smaller than
    # the requested cap.
    if max_samples is None:
        total = len(testdata)
    else:
        total = min(max_samples, len(testdata))
    if total <= 0:
        raise ValueError('no test samples to evaluate')

    rigSum = 0
    for i in range(total):
        print('iter:{}'.format(i))
        y_pred = knn(traindata,trainlabel,testdata[i],k)
        if testlabel[i] == y_pred:
            rigSum += 1

    # Divide by the number of samples actually evaluated.
    return rigSum/total*100

if __name__ == '__main__':
    # Read both CSV splits before the timer starts.
    traindata,trainlabel = data_load('dataset/mnist_train.csv')
    testdata,testlabel = data_load('dataset/mnist_test.csv')
    k = 25

    start = time.time()
    accuracy = model_test(traindata,trainlabel,testdata,testlabel,k)
    print(accuracy)
    # The printed label reads "training time", but the interval measured
    # is the model_test call plus printing its result.
    elapsed = time.time() - start
    print('訓練時間:{}'.format(elapsed))

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章