KD tree算法(2)-最近鄰搜索KD tree

書接上文,我們這一張來討論KD Tree的最近鄰搜索。首先來看一下《統計學習方法》書中給出的算法描述:
搜索算法
這裏寫圖片描述

我們用網上很常見的例子進行分析:
(7,2),(5,4),(9,6),(2,3),(4,7),(8,1)
待測節點:(2,4.5)
我們先不管在座標上是怎麼劃分的,先看看在KD Tree上是如何遍歷的。
我們先來進行分析:

(1)在kd樹中找出包含目標點x的葉節點。

這一步比較簡單,首先根據樹的每一層分割所用的維度(Flag)進行搜索,直到找到葉子節點位置(Position),搜索路徑圖表示如下:
搜索路徑圖
只需要比較該維度座標值和樹節點上該維度座標值大小即可:
1>比較目標節點(node)和樹上節點(TNode)對應維度(dia)上座標值的大小:
1.1>如果node.get(dia) <= TNode.get(dia):遞歸搜索左孩子。
1.2>否則遞歸搜索右孩子。
2>當搜索到葉子節點時,返回之並退出。

   private KDTreeNode GetLeaf(KDTreeNode Tnode, KDNode node) {
        int dia = Tnode.flag%KDNode.dimension;
        if (Tnode.getValue() == null) { return Tnode.getFather();
        //如果找到葉子節點還未找到,那就把他父節點設爲最近鄰
        } else if (node.get(dia) < Tnode.getValue().get(dia)) {
            return GetLeaf(Tnode.getLeft(), node);
        }else
            return GetLeaf(Tnode.getRight(), node);
    }

2)以此葉子節點爲“當前最近點“(CurNearest)。
3)遞歸向上會退,在每個節點進行如下操作:
a)如果該節點保存的實例點比當前最近最近點距離目標節點更近,則更新“當前最近點“。
b)當前最近點一定存在於該節點(SplitPotint)的一個子節點對應的區域。檢查該子節點的父節點的另一子節點對應的區域是否有更近的節點。具體的檢查另一子節點對應的區域是否與以目標點爲球心,目標點到當前最近點距離爲半徑的超球相交。
b.1>如果相交,則可能在另一子節點對應的區域存在距離目標點更近的點,移動到另一子節點,遞歸搜索。
b.2>如果不相交,向上退回。
4>當退回到跟節點時,搜索結束,返回當前最近點。

這段描述是該算法的核心與難點,實現過程中遇到衆多難題,其中之一便是“向上退回“時進入死循環。咱們先看描述,再來說我遇到的問題。
其實並不難理解,
第一步:從一個葉子節點向上搜索,如果父節點離node近,則更新之。
第二步:比較當前的兩點(Curnearest&node)間距離和node到超平面的距離的大小(是否相交),

  • 如果不相交,那很好說,直接搜索父節點即可。
  • 如果相交,問題來了,就需要進入其兄弟節點看看有沒有比CurNearest更近的節點。如果有,則更新之,如果沒有則返回父節點。
    -注意,此時返回父節點後,可知該父節點的兩個子節點都沒有比CurNearest更近的節點了,此時向上退回到爺爺節點。我在這一步的實現上出現了問題。問題就是在遞歸調用時,無法判斷該節點是退回到的節點還是向下搜索到的節點,也就是說無法判斷該節點的兩個子節點是否都已經被搜索過了。即如下圖所示,無法區分是綠線搜索到還是藍線搜索到,後面我來說一下我的解決方法。
    這裏寫圖片描述
    我使用了路徑棧(pathStack)的方式存儲所有搜索過的節點,這樣判斷該節點是否在棧中就可以區分啦!
    • 使用方式:搜素一個節點,首先判斷該節點是否在棧中,如果在則說明是回退搜索到的,則將該節點彈出棧,並回退到其父節點處。如果不在棧中,則說明是向下搜索到的該節點,則入棧該節點,並進行a,b操作。

實現代碼如下:

    KDTreeNode GetNearest(KDNode node, KDTreeNode nearest, KDTreeNode spiltPoint) {
        KDTreeNode CurNearest = nearest;
        if (!pathStack.contains(spiltPoint)) {//如果該節點未被便利過,則進行遍歷,避免父子之間死循環
            if (spiltPoint == null)
                return nearest;
//            if (spiltPoint == root) {//遍歷終止
//                if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue()))
//                    return root;//如果root節點比當前最近節點近則返回root節點
//                else //否則返回當前最近節點
//                    return nearest;
//            }
            pathStack.push(spiltPoint);
            if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue())) {
                CurNearest = spiltPoint;
            }//如果當前節點距離待測數據比較近,更新之。

            if (node.isTangential(CurNearest.getValue(), spiltPoint.getValue(), spiltPoint.getFlag())) {
                //發生相交的情況
                //去查找最鄰近節點的兄弟節點
                //if (spiltPoint.getLeft() == nearest && spiltPoint.getRight().getValue() != null) {
                if(nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) < spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
                            && spiltPoint.getRight().getValue() != null){
                    return GetNearest(node, CurNearest, spiltPoint.getRight());
                } //else if (spiltPoint.getRight() == nearest && spiltPoint.getLeft().getValue() != null) {
                else if (nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) >= spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
                        && spiltPoint.getRight().getValue() != null) {
                    return GetNearest(node, CurNearest, spiltPoint.getLeft());
                }
                //當不屬於兄弟節點時,查找所有子節點,類似於先序遍歷
                else {
                    if (spiltPoint.getRight().getValue() != null) {
                        return GetNearest(node, CurNearest, spiltPoint.getRight());
                    }
                    if (spiltPoint.getLeft().getValue() != null) {
                        return GetNearest(node, CurNearest, spiltPoint.getLeft());
                    }
                }
            }
        }
        //遞歸遍歷父節點
        pathStack.pop();
        return GetNearest(node, CurNearest, spiltPoint.getFather());
    }

代碼中遍歷終止處我更改了方式,將到root終止變爲了null終止,原因是當我用root終止時,就不能判斷是否與root超平面相交,也不搜索root的右子樹了。
可見我使用遞歸調用的方式實現的。這種實現方法並不好,因爲當數據量極大時,會出現棧溢出的情況,爲什麼我會知道呢?因爲實現過程中,我的遞歸終止條件有問題,報了這個錯,所以說明會出現這種情況。除此之外,在學習C語言的時候我們知道,每次遞歸調用函數時,都會新開闢一塊空間,與循環相比開銷較大,效率低,易讀性差。然而臨近考試我也沒時間改了,就先這樣吧。T^T
這一部分主要學習了最近鄰的實現,下一章我們學習一下K近鄰好啦~


爲了這一部分的方便,我添加了KDNode,KDNodeSet,KDTreeNode,KDTree中的方法,更新後的相關類內容如下:

KDNode:

class KDNode {//每一個數據
    public static int dimension = 0;
    float[] coordinate;

    KDNode() {
        if (this.dimension != 0)
            coordinate = new float[dimension];
        else
            coordinate = null;
    }

    KDNode(int dimension) {
        if (this.dimension == 0) {
            this.dimension = dimension;
            coordinate = new float[dimension];
        } else {
            if (this.dimension != this.dimension) {
                System.err.println("不允許更改節點維度!");
            }
        }
    }
    void set(int pos, float val) { coordinate[pos] = val; }
    float get(int pos) { return coordinate[pos]; }

    /** 求距離 */
    float distance(KDNode node) {
        int dis = 0;
        for (int i=0; i<dimension; i++) {
            dis += Math.pow(coordinate[i] - node.get(i), 2);
        }

        return dis;
    }

    /** 判斷是否相交 */
    boolean isTangential(KDNode NearestNode,  KDNode FlagNode, int dim) {
        //NearestNode是最近鄰,flag node是欲判斷是否相切的超平面上的點,dia是該超平面用於切割的維數
        float radius;//radius是半徑,以節點爲圓心,到最近鄰的距離。
        radius = distance(NearestNode);
        return isTangential(radius, FlagNode, dim);
    }
    boolean isTangential(float radius,  KDNode FlagNode, int dim) {
        //radius是半徑,以節點爲圓心,到最近鄰的距離,flag node是欲判斷是否相切的超平面上的點,dia是該超平面用於切割的維數
        int dis = 0;
        for (int i=0; i<dimension; i++)  if (dim == i)   continue;
        if (dis >= radius)
            return false;
        else
            return true;
    }

    @Override
    public String toString() {
        String s = " ";
        for (int i=0; i<dimension; i++)
            s += ("  " + coordinate[i]);
        return s;
    }
}

KDNodeSet:

class KDNodeSet {//用來存節點的集合,
    ArrayList<KDNode> set;
    int midPos;
    KDNodeSet() {
        set = new ArrayList<KDNode>();
    }
    KDNodeSet(KDNodeSet Nodeset) {
        this.set = Nodeset.set;
    }
    void add(KDNode node) {
        set.add(node);
    }

    KDNode findMiD(int flag) {
        midPos = set.size()/2 ;
        return findMiD(0, set.size(), flag);
    }

    private KDNode findMiD(int begin, int end, int flag) {
        if (begin >= end) {return null;}
        KDNode lastNode = set.get(end-1);
        int dia = flag%(KDNode.dimension);
        double keyValue = lastNode.get(dia);
        int LastSmall = begin-1;
        for (int i=begin; i<end-1; i++) {
            if (set.get(i).get(dia) < keyValue) {
                exchange(++LastSmall, i);
            }
        }
        exchange(end-1, ++LastSmall);
        if (midPos == LastSmall)
            return set.get(midPos);
        else if (midPos < LastSmall)
            return findMiD(begin, LastSmall, flag);
        else
            return findMiD(LastSmall+1, end, flag);
    }

    KDNodeSet findLeft() {
        return getSubSet(0, midPos);
    }

    KDNodeSet findRight() {
        return getSubSet(midPos+1, set.size());
    }

    KDNodeSet getSubSet(int begin, int end) {
        KDNodeSet subSet = new KDNodeSet();
        for (int i=begin; i<end; i++)
            subSet.add(set.get(i));
        return subSet;
    }

    private void exchange(int pos1, int pos2) {
        KDNode temp = set.get(pos1);
        set.set(pos1, set.get(pos2));
        set.set(pos2, temp);
    }

    public KDNode findNearest(KDNode node, float distance) {
        KDNode ans = null;
        for (Iterator<KDNode>ip = set.iterator(); ip.hasNext(); ip.next()) {
            if (node.distance((KDNode) ip) < distance) {
                ans = (KDNode) ip;
                distance = node.distance((KDNode) ip);
            }
        }
        return ans;
    }

    public static void main(String args[]) {
        KDNodeSet set = new KDNodeSet();
        for (int i=10; i>=0; i--) {
            KDNode node = new KDNode();
            node.set(0, i);
            set.add(node);
        }
        System.out.println(set.findMiD(1));
        System.out.println(set);
    }
}

KDTreeNode:

class KDTreeNode {//KD樹的每一個節點,保存有中位數的值,在該層有的集合和父子節點引用
    int flag = 0;
    KDNode value;
    KDTreeNode father;
    KDNodeSet set;
    KDTreeNode left;
    KDNodeSet leftSet;
    KDTreeNode right;
    KDNodeSet rightSet;

    KDTreeNode(KDTreeNode father, KDNodeSet set) {
        this.father = father;
        this.set = set;
        run();
    }

    KDTreeNode(KDTreeNode father, KDNodeSet set, int flag) {
        this.father = father;
        this.set = set;
        this.flag = flag;
        run();
    }

    void run() {
        if (set.set.size() == 0) return;

        value = set.findMiD(flag);
        leftSet = set.findLeft();
        rightSet = set.findRight();
        left = new KDTreeNode(this, leftSet, flag+1);
        right = new KDTreeNode(this, rightSet, flag+1);
    }

    KDTreeNode getFather() {return father;}
    KDTreeNode getLeft() {return left;}
    KDTreeNode getRight() {return right;}
    KDNode getValue() {return value;}
    int getFlag() {return flag;}
}

KDTree:

public class KDTree {
    KDNodeSet set;
    KDTreeNode root;
    static Stack<KDTreeNode> pathStack;
    KDTree() {
        set = new KDNodeSet();
        pathStack = new Stack<KDTreeNode>();
    }
    void addNode(KDNode node) {
        set.add(node);
    }
    void BuildTree() {
        root = new KDTreeNode(null, set);//已經建好的KDTree
    }

    void find(KDNode node) {
        KDTreeNode position = GetLeaf(root,node);//找到最近鄰的葉子節點。
        KDTreeNode NearestNode = GetNearest(node, position, position.getFather());
        System.out.println(NearestNode.getValue());
    }

    private KDTreeNode GetLeaf(KDTreeNode Tnode, KDNode node) {
        int dia = Tnode.flag%KDNode.dimension;
        if (Tnode.getValue() == null) {
            return Tnode.getFather();//返回葉子節點
        } else if (node.get(dia) < Tnode.getValue().get(dia)) {
            return GetLeaf(Tnode.getLeft(), node);
        }else
            return GetLeaf(Tnode.getRight(), node);
    }

    KDTreeNode GetNearest(KDNode node, KDTreeNode nearest, KDTreeNode spiltPoint) {
        KDTreeNode CurNearest = nearest;
        if (!pathStack.contains(spiltPoint)) {//如果該節點未被便利過,則進行遍歷,避免父子之間死循環
            if (spiltPoint == null)
                return nearest;
//            if (spiltPoint == root) {//遍歷終止
//                if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue()))
//                    return root;//如果root節點比當前最近節點近則返回root節點
//                else //否則返回當前最近節點
//                    return nearest;
//            }
            pathStack.push(spiltPoint);
            if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue())) {
                CurNearest = spiltPoint;
            }//如果當前節點距離待測數據比較近,更新之。

            if (node.isTangential(CurNearest.getValue(), spiltPoint.getValue(), spiltPoint.getFlag())) {
                //發生相交的情況
                //去查找最鄰近節點的兄弟節點
                //if (spiltPoint.getLeft() == nearest && spiltPoint.getRight().getValue() != null) {
                if(nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) < spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
                            && spiltPoint.getRight().getValue() != null){
                    return GetNearest(node, CurNearest, spiltPoint.getRight());
                } //else if (spiltPoint.getRight() == nearest && spiltPoint.getLeft().getValue() != null) {
                else if (nearest.getValue().get(spiltPoint.getFlag()%KDNode.dimension) >= spiltPoint.getValue().get(spiltPoint.getFlag()%KDNode.dimension)
                        && spiltPoint.getRight().getValue() != null) {
                    return GetNearest(node, CurNearest, spiltPoint.getLeft());
                }
                //當不屬於兄弟節點時,查找所有子節點,類似於先序遍歷
                else {
                    if (spiltPoint.getRight().getValue() != null) {
                        return GetNearest(node, CurNearest, spiltPoint.getRight());
                    }
                    if (spiltPoint.getLeft().getValue() != null) {
                        return GetNearest(node, CurNearest, spiltPoint.getLeft());
                    }
                }
            }
        }
        //遞歸遍歷父節點
        pathStack.pop();
        return GetNearest(node, CurNearest, spiltPoint.getFather());
    }
        public static void main(String[] args) {
        KDTree tree = new KDTree();
        KDNode node1 = new KDNode(2);
        node1.set(0,2);
        node1.set(1,3);
        tree.addNode(node1);
         node1 = new KDNode();
        node1.set(0,5);
        node1.set(1,4);
        tree.addNode(node1);
         node1 = new KDNode();
        node1.set(0,9);
        node1.set(1,6);
        tree.addNode(node1);
         node1 = new KDNode();
        node1.set(0,4);
        node1.set(1,7);
        tree.addNode(node1);
         node1 = new KDNode();
        node1.set(0,8);
        node1.set(1,1);
        tree.addNode(node1);
         node1 = new KDNode();
        node1.set(0,7);
        node1.set(1,2);
        tree.addNode(node1);

        tree.BuildTree();

        node1 = new KDNode();
        node1.set(0, (float) 2.1);
        node1.set(1, (float) 3.1);
        tree.find(node1);

        node1 = new KDNode();
        node1.set(0, (float) 2);
        node1.set(1, (float) 4.5);
        tree.find(node1);
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章