書接上文,我們這一張來討論KD Tree的最近鄰搜索。首先來看一下《統計學習方法》書中給出的算法描述:
我們用網上很常見的例子進行分析:
(7,2),(5,4),(9,6),(2,3),(4,7),(8,1)
待測節點:(2,4.5)
我們先不管在座標上是怎麼劃分的,先看看在KD Tree上是如何遍歷的。
我們先來進行分析:
(1)在kd樹中找出包含目標點x的葉節點。
這一步比較簡單,首先根據樹的每一層分割所用的維度(Flag)進行搜索,直到找到葉子節點位置(Position),搜索路徑圖表示如下:
只需要比較該維度座標值和樹節點上該維度座標值大小即可:
1>比較目標節點(node)和樹上節點(TNode)對應維度(dia)上座標值的大小:
1.1>如果node.get(dia) <= TNode.get(dia):遞歸搜索左孩子。
1.2>否則遞歸搜索右孩子。
2>當搜索到葉子節點時,返回之並退出。
private KDTreeNode GetLeaf(KDTreeNode Tnode, KDNode node) {
int dia = Tnode.flag%KDNode.dimension;
if (Tnode.getValue() == null) { return Tnode.getFather();
//如果找到葉子節點還未找到,那就把他父節點設爲最近鄰
} else if (node.get(dia) < Tnode.getValue().get(dia)) {
return GetLeaf(Tnode.getLeft(), node);
}else
return GetLeaf(Tnode.getRight(), node);
}
2)以此葉子節點爲“當前最近點“(CurNearest)。
3)遞歸向上會退,在每個節點進行如下操作:
a)如果該節點保存的實例點比當前最近最近點距離目標節點更近,則更新“當前最近點“。
b)當前最近點一定存在於該節點(SplitPotint)的一個子節點對應的區域。檢查該子節點的父節點的另一子節點對應的區域是否有更近的節點。具體的檢查另一子節點對應的區域是否與以目標點爲球心,目標點到當前最近點距離爲半徑的超球相交。
b.1>如果相交,則可能在另一子節點對應的區域存在距離目標點更近的點,移動到另一子節點,遞歸搜索。
b.2>如果不相交,向上退回。
4>當退回到跟節點時,搜索結束,返回當前最近點。
這段描述是該算法的核心與難點,實現過程中遇到衆多難題,其中之一便是“向上退回“時進入死循環。咱們先看描述,再來說我遇到的問題。
其實並不難理解,
第一步:從一個葉子節點向上搜索,如果父節點離node近,則更新之。
第二步:比較當前的兩點(Curnearest&node)間距離和node到超平面的距離的大小(是否相交),
- 如果不相交,那很好說,直接搜索父節點即可。
- 如果相交,問題來了,就需要進入其兄弟節點看看有沒有比CurNearest更近的節點。如果有,則更新之,如果沒有則返回父節點。
-注意,此時返回父節點後,可知該父節點的兩個子節點都沒有比CurNearest更近的節點了,此時向上退回到爺爺節點。我在這一步的實現上出現了問題。問題就是在遞歸調用時,無法判斷該節點是退回到的節點還是向下搜索到的節點,也就是說無法判斷該節點的兩個子節點是否都已經被搜索過了。即如下圖所示,無法區分是綠線搜索到還是藍線搜索到,後面我來說一下我的解決方法。
我使用了路徑棧(pathStack)的方式存儲所有搜索過的節點,這樣判斷該節點是否在棧中就可以區分啦!
- 使用方式:搜素一個節點,首先判斷該節點是否在棧中,如果在則說明是回退搜索到的,則將該節點彈出棧,並回退到其父節點處。如果不在棧中,則說明是向下搜索到的該節點,則入棧該節點,並進行a,b操作。
實現代碼如下:
KDTreeNode GetNearest(KDNode node, KDTreeNode nearest, KDTreeNode spiltPoint) {
KDTreeNode CurNearest = nearest;
if (!pathStack.contains(spiltPoint)) {//如果該節點未被便利過,則進行遍歷,避免父子之間死循環
if (spiltPoint == null)
return nearest;
// if (spiltPoint == root) {//遍歷終止
// if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue()))
// return root;//如果root節點比當前最近節點近則返回root節點
// else //否則返回當前最近節點
// return nearest;
// }
pathStack.push(spiltPoint);
if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue())) {
CurNearest = spiltPoint;
}//如果當前節點距離待測數據比較近,更新之。
if (node.isTangential(CurNearest.getValue(), spiltPoint.getValue(), spiltPoint.getFlag())) {
//發生相交的情況
//去查找最鄰近節點的兄弟節點
//if (spiltPoint.getLeft() == nearest && spiltPoint.getRight().getValue() != null) {
if(nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) < spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
&& spiltPoint.getRight().getValue() != null){
return GetNearest(node, CurNearest, spiltPoint.getRight());
} //else if (spiltPoint.getRight() == nearest && spiltPoint.getLeft().getValue() != null) {
else if (nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) >= spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
&& spiltPoint.getRight().getValue() != null) {
return GetNearest(node, CurNearest, spiltPoint.getLeft());
}
//當不屬於兄弟節點時,查找所有子節點,類似於先序遍歷
else {
if (spiltPoint.getRight().getValue() != null) {
return GetNearest(node, CurNearest, spiltPoint.getRight());
}
if (spiltPoint.getLeft().getValue() != null) {
return GetNearest(node, CurNearest, spiltPoint.getLeft());
}
}
}
}
//遞歸遍歷父節點
pathStack.pop();
return GetNearest(node, CurNearest, spiltPoint.getFather());
}
代碼中遍歷終止處我更改了方式,將到root終止變爲了null終止,原因是當我用root終止時,就不能判斷是否與root超平面相交,也不搜索root的右子樹了。
可見我使用遞歸調用的方式實現的。這種實現方法並不好,因爲當數據量極大時,會出現棧溢出的情況,爲什麼我會知道呢?因爲實現過程中,我的遞歸終止條件有問題,報了這個錯,所以說明會出現這種情況。除此之外,在學習C語言的時候我們知道,每次遞歸調用函數時,都會新開闢一塊空間,與循環相比開銷較大,效率低,易讀性差。然而臨近考試我也沒時間改了,就先這樣吧。T^T
這一部分主要學習了最近鄰的實現,下一章我們學習一下K近鄰好啦~
爲了這一部分的方便,我添加了KDNode,KDNodeSet,KDTreeNode,KDTree中的方法,更新後的相關類內容如下:
KDNode:
class KDNode {//每一個數據
public static int dimension = 0;
float[] coordinate;
KDNode() {
if (this.dimension != 0)
coordinate = new float[dimension];
else
coordinate = null;
}
KDNode(int dimension) {
if (this.dimension == 0) {
this.dimension = dimension;
coordinate = new float[dimension];
} else {
if (this.dimension != this.dimension) {
System.err.println("不允許更改節點維度!");
}
}
}
void set(int pos, float val) { coordinate[pos] = val; }
float get(int pos) { return coordinate[pos]; }
/** 求距離 */
float distance(KDNode node) {
int dis = 0;
for (int i=0; i<dimension; i++) {
dis += Math.pow(coordinate[i] - node.get(i), 2);
}
return dis;
}
/** 判斷是否相交 */
boolean isTangential(KDNode NearestNode, KDNode FlagNode, int dim) {
//NearestNode是最近鄰,flag node是欲判斷是否相切的超平面上的點,dia是該超平面用於切割的維數
float radius;//radius是半徑,以節點爲圓心,到最近鄰的距離。
radius = distance(NearestNode);
return isTangential(radius, FlagNode, dim);
}
boolean isTangential(float radius, KDNode FlagNode, int dim) {
//radius是半徑,以節點爲圓心,到最近鄰的距離,flag node是欲判斷是否相切的超平面上的點,dia是該超平面用於切割的維數
int dis = 0;
for (int i=0; i<dimension; i++) if (dim == i) continue;
if (dis >= radius)
return false;
else
return true;
}
@Override
public String toString() {
String s = " ";
for (int i=0; i<dimension; i++)
s += (" " + coordinate[i]);
return s;
}
}
KDNodeSet:
class KDNodeSet {//用來存節點的集合,
ArrayList<KDNode> set;
int midPos;
KDNodeSet() {
set = new ArrayList<KDNode>();
}
KDNodeSet(KDNodeSet Nodeset) {
this.set = Nodeset.set;
}
void add(KDNode node) {
set.add(node);
}
KDNode findMiD(int flag) {
midPos = set.size()/2 ;
return findMiD(0, set.size(), flag);
}
private KDNode findMiD(int begin, int end, int flag) {
if (begin >= end) {return null;}
KDNode lastNode = set.get(end-1);
int dia = flag%(KDNode.dimension);
double keyValue = lastNode.get(dia);
int LastSmall = begin-1;
for (int i=begin; i<end-1; i++) {
if (set.get(i).get(dia) < keyValue) {
exchange(++LastSmall, i);
}
}
exchange(end-1, ++LastSmall);
if (midPos == LastSmall)
return set.get(midPos);
else if (midPos < LastSmall)
return findMiD(begin, LastSmall, flag);
else
return findMiD(LastSmall+1, end, flag);
}
KDNodeSet findLeft() {
return getSubSet(0, midPos);
}
KDNodeSet findRight() {
return getSubSet(midPos+1, set.size());
}
KDNodeSet getSubSet(int begin, int end) {
KDNodeSet subSet = new KDNodeSet();
for (int i=begin; i<end; i++)
subSet.add(set.get(i));
return subSet;
}
private void exchange(int pos1, int pos2) {
KDNode temp = set.get(pos1);
set.set(pos1, set.get(pos2));
set.set(pos2, temp);
}
public KDNode findNearest(KDNode node, float distance) {
KDNode ans = null;
for (Iterator<KDNode>ip = set.iterator(); ip.hasNext(); ip.next()) {
if (node.distance((KDNode) ip) < distance) {
ans = (KDNode) ip;
distance = node.distance((KDNode) ip);
}
}
return ans;
}
public static void main(String args[]) {
KDNodeSet set = new KDNodeSet();
for (int i=10; i>=0; i--) {
KDNode node = new KDNode();
node.set(0, i);
set.add(node);
}
System.out.println(set.findMiD(1));
System.out.println(set);
}
}
KDTreeNode:
class KDTreeNode {//KD樹的每一個節點,保存有中位數的值,在該層有的集合和父子節點引用
int flag = 0;
KDNode value;
KDTreeNode father;
KDNodeSet set;
KDTreeNode left;
KDNodeSet leftSet;
KDTreeNode right;
KDNodeSet rightSet;
KDTreeNode(KDTreeNode father, KDNodeSet set) {
this.father = father;
this.set = set;
run();
}
KDTreeNode(KDTreeNode father, KDNodeSet set, int flag) {
this.father = father;
this.set = set;
this.flag = flag;
run();
}
void run() {
if (set.set.size() == 0) return;
value = set.findMiD(flag);
leftSet = set.findLeft();
rightSet = set.findRight();
left = new KDTreeNode(this, leftSet, flag+1);
right = new KDTreeNode(this, rightSet, flag+1);
}
KDTreeNode getFather() {return father;}
KDTreeNode getLeft() {return left;}
KDTreeNode getRight() {return right;}
KDNode getValue() {return value;}
int getFlag() {return flag;}
}
KDTree:
public class KDTree {
KDNodeSet set;
KDTreeNode root;
static Stack<KDTreeNode> pathStack;
KDTree() {
set = new KDNodeSet();
pathStack = new Stack<KDTreeNode>();
}
void addNode(KDNode node) {
set.add(node);
}
void BuildTree() {
root = new KDTreeNode(null, set);//已經建好的KDTree
}
void find(KDNode node) {
KDTreeNode position = GetLeaf(root,node);//找到最近鄰的葉子節點。
KDTreeNode NearestNode = GetNearest(node, position, position.getFather());
System.out.println(NearestNode.getValue());
}
private KDTreeNode GetLeaf(KDTreeNode Tnode, KDNode node) {
int dia = Tnode.flag%KDNode.dimension;
if (Tnode.getValue() == null) {
return Tnode.getFather();//返回葉子節點
} else if (node.get(dia) < Tnode.getValue().get(dia)) {
return GetLeaf(Tnode.getLeft(), node);
}else
return GetLeaf(Tnode.getRight(), node);
}
KDTreeNode GetNearest(KDNode node, KDTreeNode nearest, KDTreeNode spiltPoint) {
KDTreeNode CurNearest = nearest;
if (!pathStack.contains(spiltPoint)) {//如果該節點未被便利過,則進行遍歷,避免父子之間死循環
if (spiltPoint == null)
return nearest;
// if (spiltPoint == root) {//遍歷終止
// if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue()))
// return root;//如果root節點比當前最近節點近則返回root節點
// else //否則返回當前最近節點
// return nearest;
// }
pathStack.push(spiltPoint);
if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue())) {
CurNearest = spiltPoint;
}//如果當前節點距離待測數據比較近,更新之。
if (node.isTangential(CurNearest.getValue(), spiltPoint.getValue(), spiltPoint.getFlag())) {
//發生相交的情況
//去查找最鄰近節點的兄弟節點
//if (spiltPoint.getLeft() == nearest && spiltPoint.getRight().getValue() != null) {
if(nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) < spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
&& spiltPoint.getRight().getValue() != null){
return GetNearest(node, CurNearest, spiltPoint.getRight());
} //else if (spiltPoint.getRight() == nearest && spiltPoint.getLeft().getValue() != null) {
else if (nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) >= spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
&& spiltPoint.getRight().getValue() != null) {
return GetNearest(node, CurNearest, spiltPoint.getLeft());
}
//當不屬於兄弟節點時,查找所有子節點,類似於先序遍歷
else {
if (spiltPoint.getRight().getValue() != null) {
return GetNearest(node, CurNearest, spiltPoint.getRight());
}
if (spiltPoint.getLeft().getValue() != null) {
return GetNearest(node, CurNearest, spiltPoint.getLeft());
}
}
}
}
//遞歸遍歷父節點
pathStack.pop();
return GetNearest(node, CurNearest, spiltPoint.getFather());
}
public static void main(String[] args) {
KDTree tree = new KDTree();
KDNode node1 = new KDNode(2);
node1.set(0,2);
node1.set(1,3);
tree.addNode(node1);
node1 = new KDNode();
node1.set(0,5);
node1.set(1,4);
tree.addNode(node1);
node1 = new KDNode();
node1.set(0,9);
node1.set(1,6);
tree.addNode(node1);
node1 = new KDNode();
node1.set(0,4);
node1.set(1,7);
tree.addNode(node1);
node1 = new KDNode();
node1.set(0,8);
node1.set(1,1);
tree.addNode(node1);
node1 = new KDNode();
node1.set(0,7);
node1.set(1,2);
tree.addNode(node1);
tree.BuildTree();
node1 = new KDNode();
node1.set(0, (float) 2.1);
node1.set(1, (float) 3.1);
tree.find(node1);
node1 = new KDNode();
node1.set(0, (float) 2);
node1.set(1, (float) 4.5);
tree.find(node1);
}
}