knn算法原理與實現(2)kd樹算法原理和python實現

一、kd樹算法分爲兩步,第一步是構建平衡kd樹,第二部是搜索預測數據的最近鄰

二、構建kd樹

輸入:k維空間數據集T = {X_{1},X_{2}...X_{n}},其中X_{i} = {x_{i}^{1},x_{i}^{2}...x_{i}^{k}},特徵維度k,訓練樣本數維n

輸出:kd樹

從第1個特徵到第k個特徵,每次選擇一個特徵,找出該特徵取值的中位數,以此特徵的中位數劃分超平面,每次劃分都是在之前劃分的基礎進行的,也就是在上次劃分的每個子區間選擇下一特徵進行劃分,當特徵用完了,則重新從第一個特徵開始劃分,直到區域內無實例爲止,即每個樣本都在所劃分的超平面上。

這裏我採用遞歸構建kd樹,因爲實現比較簡單,構建出來一顆平衡二叉樹,也可以用B樹,詳見代碼中的buildTree函數。

三、kd樹的最近鄰搜索

輸入:已構建的kd樹,目標點X

輸出:X的最近鄰

a.在kd樹種找出包含目標點X的葉節點,方法是從根節點出發遞歸訪問kd樹,若歐氏距離小於切分點則訪問左子樹,否則訪問右子樹(平衡二叉樹搜索算法),直到找到相應的葉節點

b.從葉節點回溯,檢查被回溯的節點與X的歐氏距離,如果小於原來距離,則將其設爲最緊鄰,否則查看其另一子節點比較。循環回溯,直到根節點。

關於二叉樹的遍歷與回溯一般採用棧,不清楚的可以翻翻經典的《數據結構》

最後將完全代碼貼於下發,歡飲批評指正

__author__ = 'Gujun(Bill) '
# kd樹生成與搜索
#2018/11/05

import numpy as np

def countDistance(x1,x2):#計算歐氏 距離
    dim = len(x1)
    distance = 0
    for i in range(dim):
        distance += (x1[i]-x2[i])*(x1[i]-x2[i])
    return np.sqrt(distance)

class Node:
    def __int__(self, data, left, right,parent):
        self.data = data
        self.left = left
        self.right = right
        #self.parent = parent


def getCharNum(dataMat):  #獲取特徵數
    return dataMat.shape[1]


# def buildTree(dataMat, aproch, k,parent):  #構建kd樹,帶父節點
#     if dataMat.shape[0] > 0
#         sorted(dataMat, key=dataMat[:][aproch % (k - 1)])  #對數據排序並改變了矩陣
#         left_mat = dataMat[:][0:k / 2]
#         right_mat = dataMat[:][k / 2 + 1, :]
#         node = Node()
#         node.data = dataMat[dataMat.shape[0] / 2]
#         node.left = buildTree(left_mat, aproch + 1, k,node)
#         node.right = buildTree(right_mat, aproch + 1,k, node)
#         node.parent = parent
#          #遞歸構建kd樹
#         #node.parent = parent
#     else:
#         node = None
#     return node
def buildTree(dataMat, aproch, k):  #構建kd樹
    if dataMat.shape[0] > 0
        sorted(dataMat, key=dataMat[:][aproch % (k - 1)])  #對數據排序並改變了矩陣
        left_mat = dataMat[:][0:k / 2]
        right_mat = dataMat[:][k / 2 + 1, :]
        node = Node(dataMat[dataMat.shape[0] / 2],left_mat,right_mat)
         #遞歸構建kd樹
        #node.parent = parent
    else:
        node = None
    return node


def searchKdTree(node, inputVec, aproch, k,stack):#用堆棧或者建樹時保存父節點
    if node[aproch % (k - 1)] > inputVec[aproch % (k - 1)] and node.right != None:
        stack.append(node)
        return searchKdTree(node.right, inputVec, aproch + 1, k)
    elif node[aproch % (k - 1)] <= inputVec[aproch % (k - 1)] and node.right != None:
        stack.append(node)
        return searchKdTree(node.left, inputVec, aproch + 1, k)
    else:
        #node =
        return node

def revSearch(stack,inputVec): #逆向搜索
    minDistance = 655345
    node = Node()
    minNode = Node()
    minNode = stack[-1]
    while len(stack) > 0:
        distance = countDistance(node.data[:,-1],inputVec)#計算本身節點
        if distance < minDistance:
            minDistance = distance
            minNode = node
        if node.right != None:
            distance =  countDistance(node.right.data[:,-1],inputVec)#計算右子節點距離
            if distance < minDistance:
                minDistance = distance
                minNode = node.right
        stack.pop()#彈出最後一個元素繼續回溯
    return minNode,minDistance

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章