導語 :在上一篇《kd 樹算法之思路篇》中,我們介紹瞭如何用二叉樹格式記錄空間內的距離,並以其爲依據進行高效的索引。在本篇文章中,我們將詳細介紹 kd 樹的構造以及 kd 樹上的 kNN 算法。
作者:肖睿
編輯:宏觀經濟算命師
本文由 JoinQuant 量化課堂推出,本文的難度屬於進階(下),深度爲 level-1
閱讀本文前請掌握 kNN(level-1)的知識。
KD 樹的結構
kd 樹是一個二叉樹結構,它的每一個節點記載了【特徵座標,切分軸,指向左枝的指針,指向右枝的指針】。
其中,特徵座標是線性空間 ℝnℝnRn 中的一個點 (x1,x2,…,xn)(x1,x2,…,xn)(x1,x2,…,xn)。
切分軸由一個整數 rrr 表示,這裏 1≤r≤n1≤r≤n1≤r≤n,是我們在 nnn 維空間中沿第 rrr 維進行一次分割。
節點的左枝和右枝分別都是 kd 樹,並且滿足:如果 yyy 是左枝的一個特徵座標,那麼 yr≤xryr≤xryr≤xr;並且如果 zzz 是右枝的一個特徵座標,那麼 zr≥xrzr≥xrzr≥xr。
給定一個數據樣本集 S⊆RnS⊆RnS⊆Rn 和切分軸 rrr,以下遞歸算法將構建一個基於該數據集的 kd 樹,每一次循環制作一個節點:
−−− 如果 |S|=1|S|=1|S|=1,記錄 SSS 中唯一的一個點爲當前節點的特徵數據,並且不設左枝和右枝。(|S||S||S| 指集合 SSS 中元素的數量)
−−− 如果 |S|>1|S|>1|S|>1:
∙∙∙ 將 SSS 內所有點按照第 rrr 個座標的大小進行排序;
∙∙∙ 選出該排列後的中位元素(如果一共有偶數個元素,則選擇中位左邊或右邊的元素,左或右並無影響),作爲當前節點的特徵坐 標,並且記錄切分軸 rrr;
∙∙∙ 將 SLSLSL 設爲在 SSS 中所有排列在中位元素之前的元素; SRSRSR 設爲在 SSS 中所有排列在中位元素後的元素;
∙∙∙ 當前節點的左枝設爲以 SLSLSL 爲數據集並且 rrr 爲切分軸製作出的 kd 樹;當前節點的右枝設爲以 SRSRSR 爲數據集並且 rrr 爲切分軸製作出 的 kd 樹。再設 r←(r+1)modnr←(r+1)modnr←(r+1)modn。(這裏,我們想輪流沿着每一個維度進行分割;modnmodnmodn 是因爲一共有 nnn 個維度,在 沿着最後一個維度進行分割之後再重新回到第一個維度。)
構造 KD 樹的例子
上面抽象的定義和算法確實是很不好理解,舉一個例子會清楚很多。首先隨機在 ℝ2ℝ2R2 中隨機生成 13 個點作爲我們的數據集。起始的切分軸 r=0r=0r=0;這裏 r=0r=0r=0 對應 xxx 軸,而 r=1r=1r=1 對應 yyy 軸。
首先先沿 xxx 座標進行切分,我們選出 xxx 座標的中位點,獲取最根部節點的座標
並且按照該點的 x 座標將空間進行切分,所有 xxx 座標小於 6.276.276.27 的數據用於構建左枝,xxx 座標大於 6.276.276.27 的點用於構建右枝。
在下一步中 r=0+1=1mod2r=0+1=1mod2r=0+1=1mod2 對應 yyy 軸,左右兩邊再按照 yyy 軸的排序進行切分,中位點記載於左右枝的節點。得到下面的樹,左邊的xxx 是指這該層的節點都是沿 xxx 軸進行分割的。
空間的切分如下
下一步中 r≡1+1≡0mod2r≡1+1≡0mod2r≡1+1≡0mod2,對應 xxx 軸,所以下面再按照 xxx 座標進行排序和切分,有
最後每一部分都只剩一個點,將他們記在最底部的節點中。因爲不再有未被記錄的點,所以不再進行切分。
就此完成了 kd 樹的構造。
KD 樹上的 KNN 算法
給定一個構建於一個樣本集的 kd 樹,下面的算法可以尋找距離某個點 ppp 最近的 kkk 個樣本。
零、設 LLL 爲一個有 kkk 個空位的列表,用於保存已搜尋到的最近點。
一、根據 ppp 的座標值和每個節點的切分向下搜索(也就是說,如果樹的節點是按照 xr=axr=axr=a 進行切分,並且 ppp 的 rrr 座標小於 aaa,則向左枝 進行搜索;反之則走右枝)。
二、當達到一個底部節點時,將其標記爲訪問過。如果 LLL 裏不足 kkk 個點,則將當前節點的特徵座標加入 LLL ;如果 LLL 不爲空並且當前節點 的特徵與 ppp 的距離小於 LLL 裏最長的距離,則用當前特徵替換掉 LLL 中離 ppp 最遠的點。
三、如果當前節點不是整棵樹最頂端節點,執行 (a);反之,輸出 LLL,算法完成。
a.a.a. 向上爬一個節點。如果當前(向上爬之後的)節點未曾被訪問過,將其標記爲被訪問過,然後執行 (1) 和 (2);如果當前節點被訪 問過,再次執行 (a)。
1.1.1. 如果此時 LLL 裏不足 kkk 個點,則將節點特徵加入 LLL;如果 LLL 中已滿 kkk 個點,且當前節點與 ppp 的距離小於 LLL 裏最長的距離, 則用節點特徵替換掉 LLL 中離最遠的點。
2.2.2. 計算 ppp 和當前節點切分線的距離。如果該距離大於等於 LLL 中距離 ppp 最遠的距離 並且 LLL 中已有 kkk 個點,則在切分線另一邊不會有更近的點,執行 (三);如果該距離小於 LLL 中最遠的距離 或者 LLL 中不足 kkk 個點,則切分線另一邊可能有更近的點,因此在當前節點的另一個枝從 (一) 開始執行。
啊呃… 被這算法噎住了,趕緊喝一口下面的例子
設我們想查詢的點爲 p=(−1,−5)p=(−1,−5)p=(−1,−5),設距離函數是普通的 L2L2L2 距離,我們想找距離問題點最近的 k=3k=3k=3 個點。如下:
首先執行 (一),我們按照切分找到最底部節點。首先,我們在頂部開始
和這個節點的 xxx 軸比較一下,
ppp 的 xxx 軸更小。因此我們向左枝進行搜索:
這次對比 yyy 軸,
ppp 的 yyy 值更小,因此向左枝進行搜索:
這個節點只有一個子枝,就不需要對比了。由此找到了最底部的節點 (−4.6,−10.55)(−4.6,−10.55)(−4.6,−10.55)。
在二維圖上是
此時我們執行 (二)。將當前結點標記爲訪問過,並記錄下 L=[(−4.6,−10.55)]L=[(−4.6,−10.55)]L=[(−4.6,−10.55)]。啊,訪問過的節點就在二叉樹上顯示爲被劃掉的好了。
然後執行 (三),嗯,不是最頂端節點。好,執行 (a),我爬。上面的是 (−6.88,−5.4)(−6.88,−5.4)(−6.88,−5.4)。
執行 (1),因爲我們記錄下的點只有一個,小於 k=3k=3k=3,所以也將當前節點記錄下,有 L=[(−4.6,−10.55),(−6.88,−5.4)]L=[(−4.6,−10.55),(−6.88,−5.4)]L=[(−4.6,−10.55),(−6.88,−5.4)]。再執行 (2),因爲當前節點的左枝是空的,所以直接跳過,回到步驟 (三)。(三) 看了一眼,好,不是頂部,交給你了,(a)。於是乎 (a) 又往上爬了一節。
(1) 說,由於還是不夠三個點,於是將當前點也記錄下,有 L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)]L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)]L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)]。當然,當前結點變爲被訪問過的。
(2) 又發現,當前節點有其他的分枝,並且經計算得出 ppp 點和 LLL 中的三個點的距離分別是 6.62,5.89,3.106.62,5.89,3.106.62,5.89,3.10,但是 ppp 和當前節點的分割線的距離只有 2.142.142.14,小於與 LLL 的最大距離:
因此,在分割線的另一端可能有更近的點。於是我們在當前結點的另一個分枝從頭執行 (一)。好,我們在紅線這裏:
要用 ppp 和這個節點比較 xxx 座標:
ppp 的 xxx 座標更大,因此探索右枝 (1.75,12.26)(1.75,12.26)(1.75,12.26),並且發現右枝已經是最底部節點,因此啓動 (二)。
經計算,(1.75,12.26)(1.75,12.26)(1.75,12.26) 與 ppp 的距離是 17.4817.4817.48,要大於 ppp 與 LLL 的距離,因此我們不將其放入記錄中。
然後 (三) 判斷出不是頂端節點,呼出 (a),爬。
(1) 出來一算,這個節點與 ppp 的距離是 4.914.914.91,要小於 ppp 與 LLL 的最大距離 6.626.626.62。
因此,我們用這個新的節點替代 LLL 中離 ppp 最遠的 (−4.6,−10.55)(−4.6,−10.55)(−4.6,−10.55)。
然後 (2) 又來了,我們比對 ppp 和當前節點的分割線的距離
這個距離小於 LLL 與 ppp 的最小距離,因此我們要到當前節點的另一個枝執行 (一)。當然,那個枝只有一個點,直接到 (二)。
計算距離發現這個點離 ppp 比 LLL 更遠,因此不進行替代。
(三) 發現不是頂點,所以呼出 (a)。我們向上爬,
這個是已經訪問過的了,所以再來(a),
好,(a)再爬,
啊!到頂點了。所以完了嗎?當然不,還沒輪到 (三) 呢。現在是 (1) 的回合。
我們進行計算比對發現頂端節點與 p 的距離比 L 還要更遠,因此不進行更新。
然後是 (2),計算 ppp 和分割線的距離發現也是更遠。
因此也不需要檢查另一個分枝。
然後執行 (三),判斷當前節點是頂點,因此計算完成!輸出距離 ppp 最近的三個樣本是 L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)]L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)]L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)]。
結語
kd 樹的 kNN 算法節約了很大的計算量(雖然這點在少量數據上很難體現),但在理解上偏於複雜,希望本篇中的實例可以讓讀者清晰地理解這個算法。喜歡動手的讀者可以嘗試自己用代碼實現 kd 樹算法,但也可以用現成的機器學習包 scikit-learn 來進行計算。量化課堂的 下一篇文章 就將講解如何用 scikit-learn 進行 kNN 分類。