ML入门1.0 -- 手写KNN

全文内容

KNN简介

KNN 全称为 K Nearest Neighbors 中文又称 K- 近邻算法 ,是一种用与分类和回归的非参数统计的方法。KNN采用向量空间模型来分类,概念为相同类别的案例,彼此的相似度高,而可以借由计算与已知类别案例之相似度,来评估未知类别案例可能的分类。
KNN是一种基于实例的学习,或者是局部近似和将所有计算推迟到分类之后的惰性学习。k-近邻算法是所有的机器学习算法中最简单的之一。

写作背景

写这篇博客呢,主要是因为最近学校开了 ML 的课程,记录一下学习过程中的一些想法和踩过的坑,方便大家一起交流学习。

算法原理

KNN常用于有监督学习,作为一个分类算法,那么它的主要作用就是确定同一事物中的不同类别,例如区分三种鸾尾花,将不同身高的人类进行分类···
KNN算法的工作原理也很简单:给定一个具有若干数据的数据集,该数据集分为两个部分— 训练样本a & 测试样本b训练样本a是一些既有特征向量(特征数据)标签(类别) 的数据集合;测试样本b 是只有特征向量未知其类别的数据集合,那么确定测试样本b中的单个样本是基于该样本与训练样本a中样本的特征向量的某种距离度量,在训练样本a中找出K个与该样本距离度量最小的样本,之后基于这K个样本的标签信息进行投票,选择K个样本中出现最多的类别作为该单个测试样本标签(类别)的预测值PS:一般情况下K<=20; 距离度量选择欧式距离。

欧式距离计算公式:
d=i=0n(XiYi)2 d = \sqrt{\sum_{i=0}^{n}(Xi-Yi)^2}

以上解释涉及少量专业术语,考虑本人表述能力较差建议选择性阅读,这里给出一个非常生动的解释:
【数学】一只兔子帮你理解 kNN - 宏观经济算命椰的文章 - 知乎
link.

算法结构

Step1: 加载数据集
Step2: 分类
- 2.1 找到k个近邻
- 2.2 投票

KNN实现

这里的代码是用Python实现:
- 数据集选用:sklearn.datasets.load_iris (150个样本,3个类别)

在这里插入图片描述

Func1(): Loaddata() 读取数据集数据并划分训练集(120)和 测试集 (30)

# Step1 load iris data
def Loaddata():
    '''
    :tempDataset: the returned Bunch object
    :X: 150 flowers' data (花朵的特征数据)
    :Y: 150 flowers' label (花朵的种类标签)
    :return: X1(train_set), Y1(train_label), X2(test_set), Y2(test_label)
    '''

    tempDataset = sklearn.datasets.load_iris()
    X = tempDataset.data
    Y = tempDataset.target
    # Step2 split train set & test set
    X1, X2, Y1, Y2 = sklearn.model_selection.train_test_split(X, Y, test_size=0.2)
    return X1, Y1, X2, Y2

Func2(): euclideanDistance(x, y) 计算单个样本之间的欧式距离

def euclideanDistance(x, y):
    '''
    计算欧式距离
    :param x: 某一朵花的数据
    :param y: 另一朵花的数据
    :return: 两个花朵数据之间的欧式距离
    '''
    tempDistance = 0
    m = x.shape[0]
    for i in range(m):
        tempDifference = x[i] - y[i]
        tempDistance += tempDifference * tempDifference
    return tempDistance**0.5

Func3(): stKnnClassifierTest(X1, Y1, X2, Y2, K = 5)
)
不当调包侠 手写KNN分类器

# Step3 Classify
def stKnnClassifierTest(X1, Y1, X2, Y2, K = 5):
    '''
    :param X1: train_set
    :param Y1: train_label
    :param X2: test_set
    :param Y2: test_label
    :param K: 所选的neighbor的数量
    :return: no return
    '''
    tempStartTime = time.time()
    tempScore = 0
    test_Instances = Y2.shape[0]
    train_Instances = Y1.shape[0]
    print('the num of testInstances = {}'.format(test_Instances))
    print('the num of trainInstances = {}'.format(train_Instances))
    tempPredicts = np.zeros((test_Instances))

    for i in range(test_Instances):
        # tempDistacnes = np.zeros((test_Instances))

        # Find K neighbors
        tempNeighbors = np.zeros(K + 2)
        tempDistances = np.zeros(K + 2)

        for j in range(K + 2):
            tempDistances[j] = 1000
        tempDistances[0] = -1

        for j in range(train_Instances):
            tempdis = euclideanDistance(X2[i], X1[j])
            tempIndex = K
            while True:
                if tempdis < tempDistances[tempIndex]:
            # prepare move forward
                    tempDistances[tempIndex + 1] = tempDistances[tempIndex]
                    tempNeighbors[tempIndex + 1] = tempNeighbors[tempIndex]
                    tempIndex -= 1
            #insert
                else:
                    tempDistances[tempIndex + 1] = tempdis
                    tempNeighbors[tempIndex + 1] = j
                    break

        # Vote
        tempLabels = []
        for j in range(K):
            tempIndex = int(tempNeighbors[j + 1])
            tempLabels.append(int(Y1[tempIndex]))

        tempCounts = []
        for label in tempLabels:
            tempCounts.append(int(tempLabels.count(label)))
        tempPredicts[i] = tempLabels[np.argmax(tempCounts)]

    # the rate of correct classify
    tempCorrect = 0
    for i in range(test_Instances):
        if tempPredicts[i] == Y2[i]:
            tempCorrect += 1

    tempScore = tempCorrect / test_Instances

    tempEndTime = time.time()
    tempRunTime = tempEndTime - tempStartTime

    print(' ST KNN score: {}%, runtime = {}'.format(tempScore*100, tempRunTime))

运行结果

在这里插入图片描述

完整代码见github

link

优缺点

优点
1.简单,易于理解,易于实现,无需估计参数;
2. 适合对稀有事件进行分类;
3.特别适合于多分类问题(multi-modal,对象具有多个类别标签
缺点
1.该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果。
2.该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章