K Nearest Neighbor問題的解決——KD-TREE Implementation

命題一: 
已知的1000個整數的數組,給定一個整數,要求查證是否在數組中出現? 

命題二: 
已知1000個整數的數組,給定一個整數,要求查找數組中與之最接近的數字? 

命題三: 
已知1000個Point(包含X與Y座標)結構的數組,給定一個Point,要求查找數組中與之最接近(比如:歐氏距離最短)的點。 

命題四: 
已知1,000,000個向量,每個向量爲128維;給定一個向量,要求查找數組中與之最接近的K個向量 

  • 對於命題一,如果不考慮桶式、哈希等方式,常用的方法應該是排序後,使用折半查找。
  • 對於命題二,與命題一類似,比較折半查找得出的結果,以及附近的各一個元素,即可。整個過程相當於是把這個包含1000個數組的數據結構做成一顆二叉樹,最後只需比較葉子節點與其父節點即可。
  • 對於命題三、四其中命題三和四就是所謂的Nearest Neighbor問題。一種近似解決的方法就是KD-TREE


高維向量的KNN檢索問題,在圖像等多媒體內容搜索中是相當關鍵的。關於高維向量的討論,網上資料比較少;在此,我將一些心得分享給大家。 
與二叉樹相比,KD-TREE也採用類似的劃分方式,只不過樹中的各節點均是高維向量,因此劃分的方式,採用隨機或指定的方式選取一個維度,在該指定維度上進行劃分;整體的思想就是採用多個超平面對數據集空間進行兩兩切分,這一點,有點類似於數據挖掘中的決策樹。 

一個運用KD-TREE分割二維平面的DEMO如下: 

 

KD-Tree build的代碼如下: 
Java代碼  收藏代碼
  1. private ClusterKDTree(Clusterable[] points, int height, boolean randomSplit){  
  2.     if ( points.length == 1 ){  
  3.         cluster = points[0];  
  4.     }  
  5.     else {  
  6.         splitIndex = chooseSplitDimension//選取切分維度  
  7.             (points[0].getLocation().length,height,randomSplit);  
  8.         splitValue = chooseSplit(points,splitIndex);//選取切分值  
  9.               
  10.         Vector<Clusterable> left = new Vector<Clusterable>();  
  11.         Vector<Clusterable> right = new Vector<Clusterable>();  
  12.         for ( int i = 0; i < points.length; i++ ){  
  13.             double val = points[i].getLocation()[splitIndex];  
  14.             if ( val == splitValue && cluster == null ){  
  15.                 cluster = points[i];  
  16.             }  
  17.             else if ( val >= splitValue ){  
  18.                 right.add(points[i]);  
  19.             } else {  
  20.                 left.add(points[i]);  
  21.             }  
  22.         }  
  23.               
  24.         if ( right.size() > 0 ){  
  25.             this.right = new ClusterKDTree(right.toArray(new  
  26.             Clusterable[right.size()]),  
  27.             randomSplit ? splitIndex : height+1, randomSplit);  
  28.         }  
  29.         if ( left.size() > 0 ){  
  30.             this.left = new ClusterKDTree(left.toArray(new  
  31.             Clusterable[left.size()]),randomSplit ? splitIndex : height+1,  
  32.             randomSplit);  
  33.         }  
  34.     }  
  35. }  
  36.   
  37. private int chooseSplitDimension(int dimensionality,int height,boolean random){  
  38.     if ( !random ) return height % dimensionality;  
  39.     int rand = r.nextInt(dimensionality);  
  40.     while ( rand == height ){  
  41.         rand = r.nextInt(dimensionality);  
  42.     }  
  43.     return rand;  
  44. }  
  45.       
  46. private double chooseSplit(Clusterable points[],int splitIdx){  
  47.     double[] values = new double[points.length];  
  48.     for ( int i = 0; i < points.length; i++ ){  
  49.     values[i] = points[i].getLocation()[splitIdx];  
  50.     }  
  51.     Arrays.sort(values);  
  52.     return values[values.length/2];//選取中間值以保持樹的平衡  
  53. }  


構建完一顆KD-TREE之後,如何使用它來做KNN檢索呢?我用下面的圖來表示(20s的GIF動畫): 



使用KD-TREE,經過一次二分查找可以獲得Query的KNN(最近鄰)貪心解,代碼如下: 
Java代碼  收藏代碼
  1. private Clusterable restrictedNearestNeighbor(Clusterable point, SizedPriorityQueue<ClusterKDTree> values){  
  2.     if ( splitIndex == -1 ) {  
  3.         return cluster; //已近到葉子節點  
  4.     }  
  5.           
  6.     double val = point.getLocation()[splitIndex];  
  7.     Clusterable closest = null;  
  8.     if ( val >= splitValue && right != null || left == null ){  
  9.         //沿右邊路徑遍歷,並將左邊子樹放進隊列  
  10.         if ( left != null ){  
  11.             double dist = val - splitValue;  
  12.             values.add(left,dist);  
  13.         }  
  14.         closest = right.restrictedNearestNeighbor(point,values);  
  15.     }  
  16.     else if ( val < splitValue && left != null || right == null ) {  
  17.         //沿左邊路徑遍歷,並將右邊子樹放進隊列  
  18.         if ( right != null ){  
  19.             double dist = splitValue - val;  
  20.             values.add(right,dist);  
  21.         }  
  22.         closest = left.restrictedNearestNeighbor(point,values);  
  23.     }  
  24.     //current distance of the 'ideal' node  
  25.     double currMinDistance = ClusterUtils.getEuclideanDistance(closest,point);  
  26.     //check to see if the current node we've backtracked to is closer  
  27.     double currClusterDistance = ClusterUtils.getEuclideanDistance(cluster,point);  
  28.     if ( closest == null || currMinDistance > currClusterDistance ){  
  29.         closest = cluster;  
  30.         currMinDistance = currClusterDistance;  
  31.     }  
  32.     return closest;  
  33. }  


事實上,僅僅一次的遍歷會有不小的誤差,因此採用了一個優先級隊列來存放每次決定遍歷走向時,另一方向的節點。SizedPriorityQueue代碼的實現,可參考我的另一篇文章: 
http://grunt1223.iteye.com/blog/909739 

一種減少誤差的方法(BBF:Best Bin First)是回溯一定數量的節點: 
Java代碼  收藏代碼
  1. public Clusterable restrictedNearestNeighbor(Clusterable point, int numMaxBinsChecked){  
  2.     SizedPriorityQueue<ClusterKDTree> bins = new SizedPriorityQueue<ClusterKDTree>(50,true);  
  3.     Clusterable closest = restrictedNearestNeighbor(point,bins);  
  4.     double closestDist = ClusterUtils.getEuclideanDistance(point,closest);  
  5.     //System.out.println("retrieved point: " + closest + ", dist: " + closestDist);  
  6.     int count = 0;  
  7.     while ( count < numMaxBinsChecked && bins.size() > 0 ){  
  8.         ClusterKDTree nextBin = bins.pop();  
  9.     //System.out.println("Popping of next bin: " + nextBin);  
  10.     Clusterable possibleClosest = nextBin.restrictedNearestNeighbor(point,bins);  
  11.         double dist = ClusterUtils.getEuclideanDistance(point,possibleClosest);  
  12.         if ( dist < closestDist ){  
  13.         closest = possibleClosest;  
  14.         closestDist = dist;  
  15.     }  
  16.     count++;  
  17.     }  
  18.     return closest;  
  19. }  


可以用如下代碼進行測試: 
Java代碼  收藏代碼
  1. public static void main(String args[]){  
  2.     Clusterable clusters[] = new Clusterable[10];  
  3.     clusters[0] = new Point(0,0);  
  4.     clusters[1] = new Point(1,2);  
  5.     clusters[2] = new Point(2,3);  
  6.     clusters[3] = new Point(1,5);  
  7.     clusters[4] = new Point(2,5);  
  8.     clusters[5] = new Point(1,1);  
  9.     clusters[6] = new Point(3,3);  
  10.     clusters[7] = new Point(0,2);  
  11.     clusters[8] = new Point(4,4);  
  12.     clusters[9] = new Point(5,5);  
  13.     ClusterKDTree tree = new ClusterKDTree(clusters,true);  
  14.     //tree.print();  
  15.     Clusterable c = tree.restrictedNearestNeighbor(new Point(4,4),1000);  
  16.     System.out.println(c);  
  17. }  
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章