Kd树+BBF(最邻近、次邻近查询)Python实现

 python2.7

import numpy as np

 构建Kd树:

KD树的构造 
一维的二叉查找树很好构造,先对所有数据排序,然后每次取中值,把数据分成两半,左半为左子树,右半为右子树;然后递归下去就好了。这样可以保证构造出来的二叉树是平衡的。 
KD树处理的数据是多维的,因此每次划分需要选定某一维作为参考来划分数据。选定后所有数据按这一维排序,然后划分成左子树,右子树。参考维度的选定可以依次选,比如这一层以X维划分,下一层就以Y维,如此循环反复。更好的方法是每次选择方差最大的那一维。只要划分以后左右区域都还有数据,划分就进行下去,直到按某个节点划分完以后两边没有数据点为止。

# kd-tree每个结点中主要包含的数据结构如下
class KdNode(object):
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # k维向量节点(k维空间中的一个样本点)
        self.split = split  # 整数(进行分割维度的序号)
        self.left = left  # 该结点分割超平面左子空间构成的kd-tree
        self.right = right  # 该结点分割超平面右子空间构成的kd-tree


class KdTree(object):
    def __init__(self, data):
        k = len(data[0])  # 数据维度

        def CreateNode(split, data_set):  # 按第split维划分数据集exset创建KdNode
            if not data_set:  # 数据集为空
                return None
            # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
            # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
            # data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2  # //为Python中的整数除法
            median = data_set[split_pos]  # 中位数分割点
            split_next = (split + 1) % k  # cycle coordinates

            # 递归的创建kd树
            return KdNode(median, split,
                          CreateNode(split_next, data_set[:split_pos]),  # 创建左子树
                          CreateNode(split_next, data_set[split_pos + 1:]))  # 创建右子树

        self.root = CreateNode(0, data)  # 从第0维分量开始构建kd树,返回根节点

BBF查询:

#kdTree_bbf
class Prioritylist(object):
    def __init__(self,kdnode,priority):
        self.node=kdnode
        self.priority=priority

prioritylist=[]#存放Prioritylist p1

def InsertPriorityList(kdnode,priority):
    p1=Prioritylist(kdnode,priority)
    if len(prioritylist)==0:
        prioritylist.append(p1)
        return
    for i in range(len(prioritylist)):
        if prioritylist[i].priority>=priority:
            prioritylist.insert(i,p1)
            break
        else:
            prioritylist.append(p1)
            break

def RemovePriority(kdnode):
    for i in range(len(prioritylist)):
        if prioritylist[i].node.dom_elt==kdnode.dom_elt:
            prioritylist.pop(i)
            break

#优先级的计算,计算目标点和分割点之间的距离(某一维度),即优先级
def CalPriority(kdnode,target,split):
    return abs(kdnode.dom_elt[split]-target[split])

def CalDistance(vector1,vector2):
    return ((np.array(vector1)-np.array(vector2))**2).sum()**0.5


def BBFFindNearest(kdnode,target):
    nearest=kdnode
    sec_near=float("inf")
    priority=CalPriority(nearest,target,kdnode.split)
    InsertPriorityList(nearest,priority)
    top_node=None
    currentNode=None
    fir_dis=CalDistance(nearest.dom_elt,target)
    sec_dis=0
    while len(prioritylist)>0:
        top_node=prioritylist[0].node
        RemovePriority(top_node)
        while top_node!=None:
            if top_node.left!=None or top_node.right!=None:
                split=top_node.split
                if target[split]<=top_node.dom_elt[split]:
                    if top_node.right!=None:
                        priority=CalPriority(top_node.right,target,top_node.split)
                        InsertPriorityList(top_node.right,priority)
                    top_node=top_node.left
                else:
                    if top_node.left!=None:
                        priority=CalPriority(top_node.left,target,top_node.split)
                        InsertPriorityList(top_node.left,priority)
                    top_node=top_node.right
                currentNode=top_node
            else:
                currentNode=top_node
                top_node=None
            if currentNode!=None and (CalDistance(nearest.dom_elt,target)>CalDistance(currentNode.dom_elt,target)):
                sec_near=nearest
                nearest=currentNode
                fir_dis = CalDistance(nearest.dom_elt, target)
                sec_dis=CalDistance(sec_near.dom_elt,target)
            elif currentNode!=None and (CalDistance(nearest.dom_elt,target)<CalDistance(currentNode.dom_elt,target)):
                if sec_near==float("inf"):
                    sec_near=currentNode
                    sec_dis = CalDistance(sec_near.dom_elt, target)
                else:
                    if CalDistance(sec_near.dom_elt,target)>CalDistance(currentNode.dom_elt,target):
                        sec_near=currentNode
                        sec_dis=CalDistance(sec_near.dom_elt,target)
    return nearest,sec_near,fir_dis,sec_dis

测试:

# KDTree的前序遍历
def preorder(root):
    print root.dom_elt
    if root.left:  # 节点不为空
        preorder(root.left)
    if root.right:
        preorder(root.right)

if __name__=='__main__':
    data=[[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    kd=KdTree(data)
    preorder(kd.root)
    nearest,sec_near=BBFFindNearest(kd.root,[0,0])
    print 'The nearest point is:',nearest.dom_elt,',Distance is:',firdis
    print 'The second point is:',sec_near.dom_elt,',Distance is:',secdis

测试结果:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章