搜索Kd樹

使用KD樹進行最近鄰查找的例子

例1:

查詢點(2.1,3.1)

星號表示要查詢的點(2.1,3.1)。通過二叉搜索,順着搜索路徑很快就能找到最鄰近的近似點,也就是葉子節點(2,3)。而找到的葉子節點並不一定就是最鄰近的,最鄰近肯定距離查詢點更近,應該位於以查詢點爲圓心且通過葉子節點的圓域內。爲了找到真正的最近鄰,還需要進行相關的‘回溯'操作。也就是說,算法首先沿搜索路徑反向查找是否有距離查詢點更近的數據點。

 

    以查詢(2.1,3.1)爲例:

 

  1. 二叉樹搜索:先從(7,2)點開始進行二叉查找,然後到達(5,4),最後到達(2,3),此時搜索路徑中的節點爲<(7,2),(5,4),(2,3)>,首先以(2,3)作爲當前最近鄰點,計算其到查詢點(2.1,3.1)的距離爲0.1414,
  2. 回溯查找:在得到(2,3)爲查詢點的最近點之後,回溯到其父節點(5,4),並判斷在該父節點的其他子節點空間中是否有距離查詢點更近的數據點。以(2.1,3.1)爲圓心,以0.1414爲半徑畫圓,如下圖所示。發現該圓並不和超平面y = 4交割,因此不用進入(5,4)節點右子空間中(圖中灰色區域)去搜索;
  3. 最後,再回溯到(7,2),以(2.1,3.1)爲圓心,以0.1414爲半徑的圓更不會與x = 7超平面交割,因此不用進入(7,2)右子空間進行查找。至此,搜索路徑中的節點已經全部回溯完,結束整個搜索,返回最近鄰點(2,3),最近距離爲0.1414。

例2:

 

查詢點(2,4.5)

 

一個複雜點了例子如查找點爲(2,4.5),具體步驟依次如下:

 

  1. 同樣先進行二叉查找,先從(7,2)查找到(5,4)節點,在進行查找時是由y = 4爲分割超平面的,由於查找點爲y值爲4.5,因此進入右子空間查找到(4,7),形成搜索路徑<(7,2),(5,4),(4,7)>,但(4,7)與目標查找點的距離爲3.202,而(5,4)與查找點之間的距離爲3.041,所以(5,4)爲查詢點的最近點;
  2. 以(2,4.5)爲圓心,以3.041爲半徑作圓,如下圖所示。可見該圓和y = 4超平面交割,所以需要進入(5,4)左子空間進行查找,也就是將(2,3)節點加入搜索路徑中得<(7,2),(2,3)>;於是接着搜索至(2,3)葉子節點,(2,3)距離(2,4.5)比(5,4)要近,所以最近鄰點更新爲(2,3),最近距離更新爲1.5;
  3. 回溯查找至(5,4),直到最後回溯到根結點(7,2)的時候,以(2,4.5)爲圓心1.5爲半徑作圓,並不和x = 7分割超平面交割,如下圖所示。至此,搜索路徑回溯完,返回最近鄰點(2,3),最近距離1.5。

    上述兩次實例表明,當查詢點的鄰域與分割超平面兩側空間交割時,需要查找另一側子空間,導致檢索過程複雜,效率下降。

 

詳細解讀:

k-d tree算法原理及實現

k-d tree即k-dimensional tree,常用來作空間劃分及近鄰搜索,是二叉空間劃分樹的一個特例。通常,對於維度爲(k),數據點數爲(N)的數據集,k-d tree適用於(N\gg2^k)的情形。

1)k-d tree算法原理

k-d tree是每個節點均爲k維數值點的二叉樹,其上的每個節點代表一個超平面,該超平面垂直於當前劃分維度的座標軸,並在該維度上將空間劃分爲兩部分,一部分在其左子樹,另一部分在其右子樹。即若當前節點的劃分維度爲d,其左子樹上所有點在d維的座標值均小於當前值,右子樹上所有點在d維的座標值均大於等於當前值,本定義對其任意子節點均成立。

1.1)樹的構建

一個平衡的k-d tree,其所有葉子節點到根節點的距離近似相等。但一個平衡的k-d tree對最近鄰搜索、空間搜索等應用場景並非是最優的。

常規的k-d tree的構建過程爲:循環依序取數據點的各維度來作爲切分維度,取數據點在該維度的中值作爲切分超平面,將中值左側的數據點掛在其左子樹,將中值右側的數據點掛在其右子樹。遞歸處理其子樹,直至所有數據點掛載完畢。

a)切分維度選擇優化

構建開始前,對比數據點在各維度的分佈情況,數據點在某一維度座標值的方差越大分佈越分散,方差越小分佈越集中。從方差大的維度開始切分可以取得很好的切分效果及平衡性。

b)中值選擇優化

第一種,算法開始前,對原始數據點在所有維度進行一次排序,存儲下來,然後在後續的中值選擇中,無須每次都對其子集進行排序,提升了性能。

第二種,從原始數據點中隨機選擇固定數目的點,然後對其進行排序,每次從這些樣本點中取中值,來作爲分割超平面。該方式在實踐中被證明可以取得很好性能及很好的平衡性。

本文采用常規的構建方式,以二維平面點((x,y))的集合(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)爲例結合下圖來說明k-d tree的構建過程。

**a)**構建根節點時,此時的切分維度爲(x),如上點集合在(x)維從小到大排序爲(2,3),(4,7),(5,4),(7,2),(8,1),(9,6);其中值爲(7,2)。(注:2,4,5,7,8,9在數學中的中值爲(5 + 7)/2=6,但因該算法的中值需在點集合之內,所以本文中值計算用的是len(points)//2=3, points[3]=(7,2))

b)(2,3),(4,7),(5,4)掛在(7,2)節點的左子樹,(8,1),(9,6)掛在(7,2)節點的右子樹。

**c)**構建(7,2)節點的左子樹時,點集合(2,3),(4,7),(5,4)此時的切分維度爲(y),中值爲(5,4)作爲分割平面,(2,3)掛在其左子樹,(4,7)掛在其右子樹。

**d)**構建(7,2)節點的右子樹時,點集合(8,1),(9,6)此時的切分維度也爲(y),中值爲(9,6)作爲分割平面,(8,1)掛在其左子樹。至此k-d tree構建完成。

上述的構建過程結合下圖可以看出,構建一個k-d tree即是將一個二維平面逐步劃分的過程。

 

 

如下爲k-d tree的構建代碼:

def kd_tree(points, depth):
    if 0 == len(points):
        return None
    cutting_dim = depth % len(points[0])
    medium_index = len(points) // 2
    points.sort(key=itemgetter(cutting_dim))
    node = Node(points[medium_index])
    node.left = kd_tree(points[:medium_index], depth + 1)
    node.right = kd_tree(points[medium_index + 1:], depth + 1)
    return node

1.2)尋找d維最小座標值點

a)若當前節點的切分維度是d

因其右子樹節點均大於等於當前節點在d維的座標值,所以可以忽略其右子樹,僅在其左子樹進行搜索。若無左子樹,當前節點即是最小座標值節點。

b)若當前節點的切分維度不是d

需在其左子樹與右子樹分別進行遞歸搜索。

如下爲尋找d維最小座標值點代碼:

def findmin(n, depth, cutting_dim, min):
    if min is None:
        min = n.location
    if n is None:
        return min
    current_cutting_dim = depth % len(min)
    if n.location[cutting_dim] < min[cutting_dim]:
        min = n.location
    if cutting_dim == current_cutting_dim:
            return findmin(n.left, depth + 1, cutting_dim, min)
    else:
        leftmin = findmin(n.left, depth + 1, cutting_dim, min)
        rightmin = findmin(n.right, depth + 1, cutting_dim, min)
        if leftmin[cutting_dim] > rightmin[cutting_dim]:
            return rightmin
        else:
            return leftmin

1.3)新增節點

從根節點出發,若待插入節點在當前節點切分維度的座標值小於當前節點在該維度的座標值時,在其左子樹插入;若大於等於當前節點在該維度的座標值時,在其右子樹插入。遞歸遍歷,直至葉子節點。

如下爲新增節點代碼:

def insert(n, point, depth):
    if n is None:
        return Node(point)
    cutting_dim = depth % len(point)
    if point[cutting_dim] < n.location[cutting_dim]:
        if n.left is None:
            n.left = Node(point)
        else:
            insert(n.left, point, depth + 1)
    else:
        if n.right is None:
            n.right = Node(point)
        else:
            insert(n.right, point, depth + 1)

多次新增節點可能引起樹的不平衡。不平衡性超過某一閾值時,需進行再平衡。

1.4)刪除節點

最簡單的方法是將待刪節點的所有子節點組成一個新的集合,然後對其進行重新構建。將構建好的子樹掛載到被刪節點即可。此方法性能不佳,下面考慮優化後的算法。

假設待刪節點T的切分維度爲x,下面根據待刪節點的幾類不同情形進行考慮。

a)無子樹

本身爲葉子節點,直接刪除。

b)有右子樹

在T.right尋找x切分維度最小的節點p,然後替換被刪節點T;遞歸處理刪除節點p。

c)無右子樹有左子樹

在T.left尋找x切分維度最小的節點p,即p=findmin(T.left, cutting-dim=x),然後用節點p替換被刪節點T;將原T.left作爲p.right;遞歸處理刪除節點p。

(之所以未採用findmax(T.left, cutting-dim=x)節點來替換被刪節點,是由於原被刪節點的左子樹節點存在x維度最大值相等的情形,這樣就破壞了左子樹在x分割維度的座標需小於其根節點的定義)

如下爲刪除節點代碼:

def delete(n, point, depth):
    cutting_dim = depth % len(point)
    if n.location == point:
        if n.right is not None:
            n.location = findmin(n.right, depth + 1, cutting_dim, None)
            delete(n.right, n.location, depth + 1)
        elif n.left is not None:
            n.location = findmin(n.left, depth + 1)
            delete(n.left, n.location, depth + 1)
            n.right = n.left
            n.left = None
        else:
            n = None
    else:
        if point[cutting_dim] < n.location[cutting_dim]:
            delete(n.left, point, depth + 1)
        else:
            delete(n.right, point, depth + 1)

2)最近鄰搜索

給定點p,查詢數據集中與其距離最近點的過程即爲最近鄰搜索。

如在上文構建好的k-d tree上搜索(3,5)的最近鄰時,本文結合如下左右兩圖對二維空間的最近鄰搜索過程作分析。

**a)**首先從根節點(7,2)出發,將當前最近鄰設爲(7,2),對該k-d tree作深度優先遍歷。以(3,5)爲圓心,其到(7,2)的距離爲半徑畫圓(多維空間爲超球面),可以看出(8,1)右側的區域與該圓不相交,所以(8,1)的右子樹全部忽略。

**b)**接着走到(7,2)左子樹根節點(5,4),與原最近鄰對比距離後,更新當前最近鄰爲(5,4)。以(3,5)爲圓心,其到(5,4)的距離爲半徑畫圓,發現(7,2)右側的區域與該圓不相交,忽略該側所有節點,這樣(7,2)的整個右子樹被標記爲已忽略。

**c)**遍歷完(5,4)的左右葉子節點,發現與當前最優距離相等,不更新最近鄰。所以(3,5)的最近鄰爲(5,4)。

 

如下爲最近鄰搜索代碼:

3)複雜度分析

操作	平均複雜度	最壞複雜度
新增節點	O(logn)	O(n)
刪除節點	O(logn)	O(n)
最近鄰搜索	O(logn)	O(n)

4)scikit-learn使用

scikit-learn是一個實用的機器學習類庫,其有KDTree的實現。如下例子爲直觀展示,僅構建了一個二維空間的k-d tree,然後對其作k近鄰搜索及指定半徑的範圍搜索。多維空間的檢索,調用方式與此例相差無多。

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from sklearn.neighbors import KDTree

np.random.seed(0)
points = np.random.random((100, 2))
tree = KDTree(points)
point = points[0]

# kNN
dists, indices = tree.query([point], k=3)
print(dists, indices)

# query radius
indices = tree.query_radius([point], r=0.2)
print(indices)

fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
ax.add_patch(Circle(point, 0.2, color='r', fill=False))
X, Y = [p[0] for p in points], [p[1] for p in points]
plt.scatter(X, Y)
plt.scatter([point[0]], [point[1]], c='r')
plt.show()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章