Kd-tree原理與實現

數據應用當中,最近鄰查詢是非常重要的功能。不論是信息檢索,推薦系統,還是數據庫查詢,最近鄰查詢(Nearst Neighbor Search)可謂無處不在。它要實現的是幫助我們找到數據中和查詢最接近的一個或多個數據條目(前者叫NN search, 後者也叫kNN),其實本質上是一樣的,我在這篇博客中講的Kd-tree主要就是針對這種最近鄰搜索問題。

1. 基本原理

其實,這種問題本來是很容易解決的,只要設計好了數據相似度的度量方法(有關相似度量的方法詳細可參考我之前的博客:數據相似性的度量方法總結)計算所有數據與查詢的距離,比較大小即可。但是隨着數據量的增大以及數據維度的提高,這種方法就很難在現實中應用了,因爲效率會非常低。解決此類問題的思路基本分爲兩類:
(1)通過構建索引,快速排除與查詢相關度不大的數據;
(2)通過降維的方法,對數據條目先降維,再查詢;
前者主要是爲了解決數據量過大的問題,比較常見的有我們熟知的二叉搜索樹,Merkel tree,B-tree,quad-tree等;後者主要是爲了解決維度過大的問題,比較常見的方法有我在上一篇博客中講的LSH:LSH(Locality Sensitive Hashing)原理與實現

而我們今天要說的Kd-tree就是一種對多維歐式空間分割,從而構建的索引,屬於上面的第一類。

Kd-tree全稱叫做:k dimension tree,這是一種對於多維歐式空間分割構造的的二叉樹,其性質非常類似於二叉搜索樹。我們先回顧一下二叉搜索樹,它是一種具有如下特徵的二叉樹:
(1)若它的左子樹不爲空,則左子樹上所有結點的值均小於它的根結點的值;
(2)若它的右子樹不爲空,則右子樹上所有結點的值均大於它的根結點的值;
(3)它的左、右子樹也分別爲二叉搜索樹;
這個概念是數據結構基礎的東西,應該非常熟悉了,不再贅述,下面給出一棵普通的二叉搜索樹的圖:
這裏寫圖片描述

如果我們把二叉搜索樹所對應的數據集看做一個一維空間(因爲這個數據集的每一個數據條目都是由一個單一的數值構成的),那麼實際上二叉搜索樹的分割依據就是數值的大小,這樣的劃分,幫助我們以平均O(lg(n))的時間複雜度搜索數據。

自然而然,我們會祥這樣一個問題,能不能在多維歐式空間中,構建一棵類似原理的二叉搜索樹?這也就是我們今天說的Kd-tree.

2. kd-tree的構建

先拋開搜索算法怎樣設計這件事不管,我們單純地關心怎樣對多維歐式空間劃分。一維空間簡單,因爲每個數據條目只有一個數值,我們直接比較數值大小,就能對這些數據條目劃分,可是在多維空間就存在一個關鍵問題:每個數據條目由多個數值組成,我們怎麼比較?

Kd-tree的原理是這樣的:我們不比較全部的k維數據,而是選擇其中某一個維度比較,根據這個維度進行空間劃分。那接下來,我們需要做的是兩件事:

  • 判斷出在哪一個維度比較,也就是說,我們所要切割的面在哪一個維度上。當然這種切割需要遵循一個基本要求,那就是儘量通過這個維度的切割,使得數據集均分(爲二);
  • 判斷以哪個數據條目分依據劃分。上面我們說,要使得數據集均分爲二,那當然要選擇一個合適的數據項,充當這個劃分的“點”。

總結一下,就是要選擇一個數據項,以這個數據項的某個維度的值爲標準,同一維度的值大於這個值的數據項,劃分爲一部分,小於的劃分爲另一部分。根據這種劃分來構建二叉樹,就如同二叉搜索樹那樣。

現在,針對上面的兩件事,我們需要做如下兩個工作:
1. 確定劃分維度:這裏維度的確定需要注意的是儘量要使得這個維度上所有數據項數值的分佈儘可能地有大方差,也就是說,數據在這個維度上儘可能分散。這就好比是我們切東西,如果你切的是一根黃瓜,當讓橫着切要比豎着切更容易。所以我們應該先對所有維度的數值計算方差,選擇方差最大的那個維度;
2. 選擇充當切割標準的數據項:那麼只需要求得這個維度上所有數值的中位數即可;

至此,可以設計出kd-tree的構建算法了:

  • 對於一個由n維數據構成的數據集,我們首先尋找方差最大的那個維度,設這個維度是d ,然後找出在維度d 上所有數據項的中位數m ,按m 劃分數據集,一分爲二,記這兩個數據子集爲Dl,Dr 。建立樹節點,存儲這次劃分的情況(記錄劃分的維度d 以及中位數m );
  • Dl,Dr 重複進行以上的劃分,並且將新生成的樹節點設置爲上一次劃分的左右孩子;
  • 遞歸地進行以上兩步,直到不能再劃分爲止(所謂不能劃分是說當前節點中包含的數據項的數量小於了我們事先規定的閾值,不失一般性,我在此篇博客中默認這個閾值是2,也就是說所有葉子節點包含的數據項不會多於2條),不能再劃分時,將對應的數據保存至最後的節點中,這些最後的節點也就是葉子節點。

現在可以給出kd-tree的實現代碼。當然,首先需要設計幾個函數,供算法調用,限於篇幅,這裏只是給出功能說明:

類或函數 作用
class-KdTreeNode kd-tree節點,包含以下6個Attributes
Attribute1-data 樹節點屬性,代表這個節點的數據項,其實是一個列表,如果不是葉子節點,則爲空
Attribute2-split 樹節點屬性,代表構建樹時,對這個節點進行分割所依據的數據維度
Attribute3-median 樹節點屬性,代表構建樹時,所有上面split維度上數據的中位數
Attribute4-left 樹節點屬性,代表左孩子
Attribute5-right 樹節點屬性,代表右孩子
Attribute6-parent 樹節點屬性,代表父親節點,作用是在後面的搜索算法中用
Attribute7-visited 樹節點屬性,代表此節點是否被算法回溯遍歷,作用是在後面的搜索算法中用
func-getSplit 函數,得到所有維度中方差最大那個維度的序號
func-getMedian 函數,得到要分割的維度的中位數

按照上面這樣設計,就可以實現kd-tree的構建了。我們這裏使用numpy庫,假設現在已經將所有的數據項讀入爲一個ndarray型的數據矩陣datamatrixdatamatrix的每一行代表了一個數據項。那麼構建樹算法的實現代碼可以如下所示:

import numpy as np

# 樹節點類和其相關方法如下
class KdTreeNode(object):

    def __init__(self, dataMatrix):

        self.data = dataMatrix

        self.left, self.right = None, None
        self.parent = None

        self.split = self.getSplit()
        self.median = self.getMedian()

        self.visited = False

    def getSplit(self):# 取方差最大的維度作爲分割維度,代碼略

    def getMedian(self):# 得到這個分割維度上所有數值的中位數,代碼略

# 構建kd-tree的函數,helper爲其輔助函數,起到遞歸的作用
def buildKdTree(dataMatrix):

    root = KdTreeNode(dataMatrix)

    # there is only one data item in dataMatrix
    if root.data.shape[0] <= 1:
        return root

    helper(root)
    return root


def helper(root):

    if root is None or len(root.data) <= 2:
        return

    # distribute data into left and right
    leftData, rightData = [], []

    # generate left and right child
    for row in list(root.data):
        if row[root.split] <= root.median:
            leftData.append(row)
        else:
            rightData.append(row)

    left = KdTreeNode(np.array(leftData))
    left.parent = root

    right = KdTreeNode(np.array(rightData))
    right.parent = root

    root.data = None
    root.left = left
    root.right = right

    helper(root.left)
    helper(root.right)

我在這裏,借用博客Kd-Tree算法原理和開源實現代碼中的測試樣例:數據集合(2,3), (5,4), (9,6), (4,7), (8,1), (7,2),按照以上算法原理設計的kd-tree以及劃分情況如以下兩張圖所示:我在這裏直接借用了上面這個鏈接中博客的圖,這位博主的文章思路寫的非常清晰。

這裏寫圖片描述
圖中,非葉節點的二元組中,第一個元素表示分割維度(split值),第二個維度表示,取得的中位數(median值)

3. 搜索算法

構建好kd-tree後,就可以執行搜索算法了。其實,這也是信息檢索最常見的模式,先構建索引,然後依照索引執行搜索算法。當然幾乎所有的搜索算法都與其索引是配套的,也就是說,即便是同樣的數據,索引不同,其搜索算法就不同,而各有各的技巧。這也是信息檢索技術最大的魅力之一。

閒話少說,看搜索算法。基本思路可分爲如下3步:

  1. 依照非葉節點中存儲的分割維度以及中位數信息,自根節點始,從上向下搜索,直到到達葉子。遍歷的原則當然是比較分割維度上,查詢值與中位數的大小,設查詢爲Q,當前遍歷到的節點爲u,則若Q[u.split] > u.median,繼續遍歷u的右子樹,反之,遍歷左子樹。
  2. 遍歷到葉子之後,計算葉子節點中與查詢Q距離最小的數據項與查詢的距離,記爲minDis;其後執行“回溯”操作,回溯至當前節點的父節點,判斷以Q爲球心,以minDis爲半徑的超球面是否與這個父節點的另一個分支所代表的區域有交集(其實,這裏的區域就是一個超矩形,它包含了所有這個節點代表的數據項)。如果沒有,繼續向上一層回溯;如果有,則按照1步繼續執行,探底到葉子節點後,如果此時Q與這個葉子節點中的數據項有更小的距離,則更新minDis
  3. 持續進行以上兩步,直到回溯至根節點,且根節點的兩個分支都被“探測”過爲止。

但是這個裏面有一個難點:如何判斷以查詢Q爲球心,以當前的minDis爲半徑的超球面與樹中,一個非葉節點所代表的超矩形是否相交?
一種簡單的方法是在構建樹的時候直接給每個節點賦值一個超矩形,這個超矩形以一個樹節點屬性的形式存在。一般情況下是給出超矩形的一個最大點和一個最小點。判斷的方法只需要看如下的兩個條件是否都成立即可:

  • Q[u.split] + minDis >= minPoint[u.split]
  • Q[u.split] - minDis >= maxPoint[u.split]

其中,u爲查詢當前遍歷到的節點的父節點,minPoint與maxPoint爲u所代表的超矩形的最大點和最小點(所謂最大最小點,那二維空間的矩形來說,就是他的右上角的點和左下角的點,分別擁有這個矩形範圍內各個維度上的最大值和最小值)

原因很簡單,因爲以Q爲球心,以當前這個矩形區域的一個點爲球面上一點的一個超球面,一定是經過了當前這個葉子所代表的區域,但是同時它不可能完全覆蓋他的兄弟節點代表的區域。這個道理聽上去有點亂,看下面這個圖就能明白:


圖中,Q1,Q2,Q3是三個查詢點,線段AB是這個矩形空間的分割情況。可見,上面的結論書成立的,同時,我們還可以得到一個觀點:只要|Q[u.split] - u.median|<= minDis那麼就是與其兄弟節點所代表的區域相交。其實這個道理也可以通過數學上的推導得到,如果不能理解的話一試便知。

說道這裏,可以給出搜索算法的實現代碼了:

import math

# 計算兩個多維向量的歐式距離
def dis(item, query):代碼略

# 回溯,找尋需要處理的下一個節點,下一節點應滿足不曾被算法回溯遍歷
def findNextNode(cur):代碼略

# 判斷以查詢爲球心,以此時的最小距離minDis爲半徑的超球面是否與節點所代表的超矩形相交
def intersect(node, query, radius):代碼略

# 找到節點的兄弟節點
def getBrother(node):代碼略


def search(root, query, result, minDis):

    cur = root

    # the root is None
    if not cur:
        return result

    # find leaf
    elif not cur.visited:
        while cur.left and cur.right:
            if query[cur.split] >= cur.median:
                cur = cur.right
            else:
                cur = cur.left

        # update the min dis if it is necessary
        for item in list(cur.data):
            tempDis = dis(item, query)
            if abs(tempDis - minDis) < 1e-9:
                result.append(list(item))
            elif tempDis < minDis:
                minDis = tempDis
                result = [list(item)]


        # update the visited
        cur.visited = True

        # process the next node
        cur = findNextNode(cur)
        if intersect(cur, query, minDis):
            return search(cur, query, result, minDis)
        else:
            cur.visited = True
            nextNode = findNextNode(cur)
            return search(nextNode, query, result, minDis)
    else:
        return result

依照算法的設計,我們以上面的kd-tree的圖爲例,可以看看搜索算法遍歷的順序:

  1. 查詢點(8, 3)自根節點起,按照分割維度以及中位數向下遍歷,找到葉子節點(9, 6),此時算得的最小距離爲10
  2. 回溯,找到下一個需要處理的節點,也就是(8,1), (7,2)這個點(此時以(8,3)爲圓心,以10 爲半徑的圓與這個點所代表區域相交),數據項 (7,2)與查詢(8, 3)的距離更近,爲2 ,更新最小距離爲2
  3. 回溯,此時,非葉節點<2, 2>這個點所在的分支已經被訪問過了,找到下一個需要處理的節點,<2, 4>這個點。不過計算距離發現,這個點所代表的區域並不與此時的圓相交,放棄對這一分支的搜索;
  4. 回溯至根節點,並且此時根節點的兩個分支都被考慮了,搜索結束,返回最近鄰(7, 2),最短距離是2

以上就是全部kd-tree的原理以及對應搜索算法的實現。內容我大多參考了博客:Kd-Tree算法原理和開源實現代碼
限於篇幅,本篇博客並未給出全部的詳細代碼,若要參考,請查看我的github主頁:KD-tree

不足之處,還望指正。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章