KD樹算法

與傳統的KNN算法比較我感覺慢很多,我的姿勢是不是不對

kd樹

import numpy as np
from numpy import *

class KDNode():
    """
    KDNode
    point:該節點的樣本點
    split:用於判斷分割的維度(屬性)
    left:左節點
    right:右節點
    """
    def __init__(self, point=None, split=None, left=None, right=None):
        self.point = point
        self.split = split
        self.left = left
        self.right = right

class KDTree():
    """
    KD樹
    KDNode:kd-tree的節點
    dimensions:數據的緯度
    right:節點的右子節點
    left:節點的左子節點
    curr_axis:當前需要切分的緯度
    next_axis:下一次需要切分的緯度
    """
    def __init__(self,data=None):
        """
        採用遞歸的方式創建樹
        """
        def createNode(split=None, data_set=None):
            """
            遞歸創建節點
            input:(1)split:分割的維度(2)data_set:需要分割的樣本點集合
            output:KDnode
            """
            if len(data_set) == 0:
                return None # 數據集爲空,作爲遞歸的停止條件
            # 按照split對data_set進行排序,找到split維度中的中位數
            data_set = list(data_set)
            data_set.sort(key=lambda x: x[split]) # 按照split維的數據大小排序
            data_set = np.array(data_set)
            median = len(data_set) // 2 # 不用python自帶的median函數,我返回的是median的位置所在的索引
            # data_set[median]就是這個節點的樣本點
            # split是這個節點的分割維度
            # data_set[:median]樣本節點左半部分 data_set[median-1:]
            print("------------",median)
            print('data_set[:median]',data_set[:median])
            print('data_set[median+1:]',data_set[median+1:])
            return KDNode(data_set[median],split,createNode(maxVar(data_set[:median]),data_set[:median]),createNode(maxVar(data_set[median+1:]),data_set[median+1:]))

        def maxVar(data_set=None):
            """
            計算樣本集的最大方差維度
            input:data_set樣本集
            output:split:最大方差的維度,作爲createNode的輸入值
            """
            if len(data_set)==0:
                return 0
            print("======",len(data_set))
            data_mean = np.mean(data_set,0) # 按照列求平均值
            print(data_mean)
            mean_differ = data_set - data_mean # 求均值差
            data_var = np.sum(mean_differ ** 2, axis=0)/len(data_set) # 求方差,差反映數據的分散特徵,方差的數值越大,那麼數據的分散程度越大
            re = np.where(data_var == np.max(data_var)) # 尋找方差最大的位置
            print("re:",re)
            return re[0][0] # 方差最大的維數
        # print(data)
        self.root = createNode(maxVar(data),data)

def computeDist(pt1,pt2):
    """
    計算兩個點之間的距離
    點的類型是N維的
    """
    sum = 0.0
    for i in range(len(pt1)):
        sum = sum + (pt1[i] - pt2[i]) ** 2
    return np.math.sqrt(sum)

def preOrder(root):
    """
    前序遍歷KD樹
    """
    print(root.point)
    if root.left:
        preOrder(root.left)
    if root.right:
        preOrder(root.right)

def updateNN(min_dist_array=None, tmp_dist=0.0, NN=None, tmp_point=None, k=1):
    """
    更新近鄰點和對應的最小距離的集合
    min_dist_array爲最小距離的集合
    NN爲鄰近點的集合
    tmp_dist和tmp_point分別是需要更新到min_dist_array,NN裏的近鄰點和距離
    """
    # 如果距離更小就更新min_dist_array
    if tmp_dist <= np.min(min_dist_array):
        # 刪除最大距離和對應的節點
        for i in range(k-1,0,-1):
            min_dist_array[i] = min_dist_array[i-1]
            NN[i] = NN[i-1]

        min_dist_array[0] = tmp_dist
        NN[0] = tmp_point
        return NN,min_dist_array
    for i in range(k) :
        if (min_dist_array[i] <= tmp_dist) and (min_dist_array[i+1] >= tmp_dist) :
            #tmp_dist在min_dist_array的第i位和第i+1位之間,則插入到i和i+1之間,並把最後一位給剔除掉
            for j in range(k-1,i,-1) : #range反向取值
                min_dist_array[j] = min_dist_array[j-1]
                NN[j] = NN[j-1]
            min_dist_array[i+1] = tmp_dist
            NN[i+1] = tmp_point
            break
    return NN,min_dist_array

def searchKDTree(KDTree=None, target_point=None, k=1):
    """
    搜索KD樹
    input:KDTree:kd樹;target_point:目標點;k:距離目標點最近的k個點的k值
    output:k_arrayList,距離目標點最近的k個點的集合數組
    """
    if k==0 : return None

    tempNode = KDTree.root # 從更節點出發
    NN = [tempNode.point] * k #定義最鄰近點集合,k個元素,按照距離遠近,由近到遠。初始化爲k個根節點
    min_dist_array = [float("inf")] * k#定義近鄰點與目標點距離的集合.初始化爲無窮大
    nodeList = []

    def buildSearchPath(tempNode=None, nodeList=None,min_dist_array=None,NN=None,target_point=None):
        """
        此方法是用來建立以tempNode爲根節點,以下所有節點的查找路徑,並將它們存放到nodeList中
        nodeList爲一系列節點的順序組合,按此先後順序搜索最鄰近點
        tempNode爲"根節點",即以它爲根節點,查找它以下所有的節點(空間)
        """
        while tempNode:
            nodeList.append(tempNode)
            split = tempNode.split
            point = tempNode.point
            tmp_dist = computeDist(point,target_point)
            if tmp_dist < np.max(min_dist_array):
                NN,min_dist_array = updateNN(min_dist_array,tmp_dist,NN,point,k)# 更新最小距離和最近鄰近點
            if target_point[split] <= point[split]:#如果目標點當前維的值小於等於切分點的當前維座標值,移動到左節點
                tempNode = tempNode.left
            else:
                tempNode.right
        return NN,min_dist_array


    # 建立查找路徑
    NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)
    # 回溯查找
    while nodeList:
        back_node = nodeList.pop()
        split = back_node.split
        point = back_node.point
        #判斷是否需要進入父節點搜素
        #如果當前緯度,目標點減實例點大於最小距離,就沒必要進入父節點搜素了
        #因爲目標點到切割超平面的距離很大,那鄰近點肯定不在那個切割的空間裏,即沒必要進入那個空間搜素了
        if not abs(target_point[split] - point[split]) >= np.max(min_dist_array):
            if target_point[split] <= point[split]: # 在右側
                tempNode = back_node.right
            else:
                tempNode = back_node.left # 在左側
            if tempNode:
                NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)
    return NN,min_dist_array

def classify0(inX, dataSet, labels, k):
    '''
    k近鄰算法的分類器
    input:
    inX:目標點
    dataSet:訓練點集合
    labels:訓練點對應的標籤
    k:k值
    這個方法的目的:已知訓練點dataSet和對應的標籤labels,確定目標點inX對應的labels
    ''' 
    kd = KDTree(dataSet)#構建dataSet的kd樹
    NN,min_dist_array = searchKDTree(kd, inX, k)#搜索kd樹,返回最近的k個點的集合NN,和對應的距離min_dist_array
    dataSet = dataSet.tolist()
    voteIlabels = []
    #多數投票法則確定inX的標籤,爲防止邊界處分類不準的情況,以距離的倒數爲權重,即距離越近,權重越大,越該認爲inX是屬於該類
    for i in range(k) :
        #找到每個近鄰點對應的標籤
        nni = list(NN[i])
        voteIlabels.append(labels[dataSet.index(nni)])

#     #開始記數,加權重的方法
#     uniques = np.unique(voteIlabels)
#     counts = [0.0] * len(uniques)
#     for i in range(len(voteIlabels)) :
#         for j in range(len(uniques)) :
#             if voteIlabels[i] == uniques[j] :
#                 counts[j] = counts[j] + uniques[j] / min_dist_array[i] #權重爲距離的倒數
#                 break
    #開始記數,不加權重的方法
    uniques, counts = np.unique(voteIlabels, return_counts=True)
    return uniques[np.argmax(counts)]

# 處理文件數據
def file2matrix(filename):
    fr = open(filename) # 打開文件
    arrayOlines = fr.readlines() #讀取文件
    numbersOfLines = len(arrayOlines) # 文件有多少行
    returnMat = zeros((numbersOfLines,3)) # 創建0矩陣
    classLabelVector = [] # 標籤集合
    index = 0
    for line in arrayOlines:
        line = line.strip()#移除字符串頭尾的空格
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3] # 取前三個數據然後給切片賦值
        classLabelVector.append(int(listFromLine[-1])) # 最後一個是標籤
        index += 1
    return returnMat,classLabelVector

# 歸一化特徵值
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals,(m,1))
    normDataSet = normDataSet/tile(ranges,(m,1))
    return normDataSet,ranges,minVals


def datingClassTest():
    hoRatio = 0.1 # 測試樣本的比例
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') # 載入數據
    normMat,ranges,minVals = autoNorm(datingDataMat) # 歸一化處理
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio) # 獲取測試樣本
    errorCount = 0.0
    print(type(datingDataMat))
    print(type(datingLabels))
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)

if __name__ == "__main__":
    # test()
    # test2()
    datingClassTest()
發佈了163 篇原創文章 · 獲贊 41 · 訪問量 20萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章