機器學習實戰_K近鄰算法 —— 電影分類

一、數據參考

在這裏插入圖片描述

二、代碼

import numpy as np
import operator


def createDataSet():
    """
    函數說明:創建數據集

    Parameters:
        None

    Returns:
        group - 數據集
        labels - 分類標籤

    """
    # 七組二維特徵
    group = np.array([[3, 104],
                      [2, 100],
                      [1, 81],
                      [101, 10],
                      [99, 5],
                      [98, 2],
                      [18, 90]])
    # 七組特徵的標籤
    labels = ['愛情片', '愛情片', '愛情片', '動作片', '動作片', '動作片', "未知"]
    return group, labels


def classify0(inX, dataSet, labels, k):
    """
    函數說明:kNN算法,分類器

    Parameters:
        inX - 用於分類的數據(測試集)(1*m向量)
        dataSet - 用於訓練的數據(訓練集)(n*m向量array)
        labels - 分類標準(n*1向量array)
        k - kNN算法參數,選擇距離最小的k個點

    Returns:
        sortedClassCount[0][0] - 分類結果

    """
    # numpy函數shape[0]獲取dataSet的行數
    dataSetSize = dataSet.shape[0]
    # 將inX重複dataSetSize次並排成一列,即將inX賦值dataSetSize行、1列
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet   # tile:複製函數
    # 矩陣數乘:矩陣對應位置元素相乘(array()函數中矩陣的乘積可以使用np.matmul或者.dot()函數。而星號乘 (*)則表示矩陣對應位置元素相乘,與numpy.multiply()函數結果相同)
    sqDiffMat = diffMat ** 2  # 每個元素 ** 2
    # sum()所有元素相加,sum(0)列相加,sum(1)行相加
    sqDistances = sqDiffMat.sum(axis=1)
    # 開方,計算出距離
    distances = sqDistances ** 0.5  # 每個元素 ** 0.5
    # argsort函數返回的是distances值從小到大排序後的索引值
    sortedDistIndicies = distances.argsort()
    # 定義一個記錄類別次數的字典
    classCount = {}
    # 選擇距離最小的k個點
    for i in range(k):
        # 取出前k個元素的類別
        voteIlabel = labels[sortedDistIndicies[i]]
        # 字典的get()方法,返回指定鍵的值,如果值不在字典中返回0
        # 計算類別次數
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # python3中用items()替換python2中的iteritems()
    # key = operator.itemgetter(1)根據字典的值進行排序
    # key = operator.itemgetter(0)根據字典的鍵進行排序
    # reverse降序排序字典
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    print("sortedClassCount:", sortedClassCount)
    # 返回次數最多的類別,即所要分類的類別
    return sortedClassCount[0][0]


if __name__ == '__main__':
    group, labels = createDataSet()

    result = classify0([70, 5], group, labels, 3)
    print(result)

    result = classify0([9, 79], group, labels, 3)
    print(result)

三、運行結果

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章