理論介紹
kd樹(K-dimension tree)是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構。kd樹是是一種二叉樹,表示對k維空間的一個劃分,構造kd樹相當於不斷地用垂直於座標軸的超平面將K維空間切分,構成一系列的K維超矩形區域。kd樹的每個結點對應於一個k維超矩形區域。利用kd樹可以省去對大部分數據點的搜索,從而減少搜索的計算量。
- kd樹的算法步驟
- kd樹可以較大的減少搜索空間,提高搜索效率,在knn,光線追蹤這些需要大量搜索的算法中具有重要的應用前景。
算法實現
- 使用C++ 對kd樹進行了實現,沒有使用模板,如有錯誤歡迎指出~
#pragma once
#include <iostream>
#include <vector>
#include <algorithm>
#include <fstream>
using namespace std;
int dimUsed;
// kd樹的節點定義
struct KdNode
{
KdNode * parent;
KdNode * leftChild;
KdNode * rightChild;
vector<int> val; //存儲的數據
int axis; // 使用的軸
KdNode(vector<int> data, int ax)
{
val = data;
axis = ax;
parent = NULL;
leftChild = NULL;
rightChild = NULL;
}
};
// 用於以第n位進行比較,方便重載
bool cmp(vector<int> a, vector<int> b)
{
if (a[dimUsed] < b[dimUsed])
return true;
return false;
}
ostream & operator<<(ostream & os, vector<int> vi)
{
os << "(";
for (int i = 0; i < vi.size(); i++)
cout << vi[i] << ",";
os << ")";
return os;
}
// 導入數據,第一行是數據數量和維度,後面跟着num行dim維度的數據,每個一行
bool loadData(string filename, vector<vector<int> > &data)
{
ifstream infs(filename);
if (infs.is_open())
{
int num,dim;
infs >> num>>dim;
vector<int> d(dim);
for (int i = 0; i < num; i++)
{
for (int j = 0; j < dim; j++)
infs >> d[j];
data.push_back(d);
}
return true;
}
return false;
}
// 計算N維向量距離
int disVector(vector<int> a, vector<int> b)
{
int sum = 0;
for (int i = 0; i < a.size(); i++)
sum += (a[i] - b[i])*(a[i] - b[i]);
return sum;
}
// kd 樹的類
class KdTree
{
private:
int dimension;
vector<vector<int> > data;
KdNode * root;
public:
KdTree(vector<vector<int> > d, int dim)
{
dimension = dim;
data = d;
}
void createTree()
{
// 遞歸建樹
root = createTreeNode(0, data.size()-1,0);
}
// create Kd Tree struct
KdNode * createTreeNode(int left, int right,int dim)
{
if (right < left)
return NULL;
dimUsed = dim;
// 按照k維進行排序
sort(data.begin() + left, data.begin() + right+1, cmp);
int mid = (left + right+1) / 2;
KdNode * r = new KdNode(data[mid], dim);
r->leftChild = createTreeNode(left, mid - 1, (dim + 1) % dimension);
if (r->leftChild != NULL)
r->leftChild->parent = r;
r->rightChild = createTreeNode(mid + 1, right, (dim + 1) % dimension);
if (r->rightChild != NULL)
r->rightChild->parent = r;
return r;
}
void printKdTree()
{
printKdTreeNode(root);
}
void printKdTreeNode(KdNode * r)
{
if (r == NULL)
return;
printKdTreeNode(r->leftChild);
cout << r->val << "\t";
printKdTreeNode(r->rightChild);
}
// 查找kd樹
KdNode * searchKdTree(vector<int> d)
{
int dim = 0,minDis = 10000000;
KdNode * r = root;
KdNode * tmp;
while (r != NULL)
{
tmp = r;
if (d[dim] < r->val[dim])
r = r->leftChild;
else
r = r->rightChild;
dim = (dim + 1) % dimension;
}
// 找到屬於的那個節點
minDis = min(disVector(d, tmp->val),minDis);
KdNode * nearNode = tmp;
cout << endl<<"nearest node: "<<tmp->val << endl;
// 退回到根節點
while (tmp->parent != NULL)
{
tmp = tmp->parent;
// 判斷父節點是否更近,如果近,更新最近節點
if (disVector(tmp->val, d) < minDis)
{
nearNode = tmp;
minDis = disVector(tmp->val, d);
}
cout << "now parent node: " << tmp->val << endl;
KdNode * son;
// 判斷當前軸與點的距離,如果小於minDis,則進行到另一半進行查找
if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
{
if (tmp->val[tmp->axis] > d[tmp->axis])
son = tmp->rightChild;
else
son = tmp->leftChild;
searchKdTreeNode(d, minDis, nearNode, son);
}
}
// 對根節點的另外半邊子樹進行搜索
/*if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
{
if (tmp->val[tmp->axis] > d[tmp->axis])
tmp = tmp->rightChild;
else
tmp = tmp->leftChild;
searchKdTreeNode(d, minDis, nearNode, tmp);
}*/
return nearNode;
}
// 查找當前節點下的最近點
void searchKdTreeNode(vector<int> d,int &minDis,KdNode * &nearNode,KdNode * tmp)
{
// 遞歸終止
if (tmp == NULL)
return;
cout << "now node: " << tmp->val << endl;
// 判斷當前節點是否小於
if (disVector(tmp->val, d) < minDis)
{
minDis = disVector(tmp->val, d);
nearNode = tmp;
}
// 如果軸與節點的距離小於minDis,則兩個半邊都需要搜索,否則只需要搜索一個半邊
if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
{
searchKdTreeNode(d, minDis, nearNode, tmp->leftChild);
searchKdTreeNode(d, minDis, nearNode, tmp->rightChild);
}
else
{
// 選擇搜索的一個半邊
if (tmp->val[tmp->axis] > d[tmp->axis])
searchKdTreeNode(d, minDis, nearNode, tmp->leftChild);
else
searchKdTreeNode(d, minDis, nearNode, tmp->rightChild);
}
}
};
// 測試kd樹
void testKdTree()
{
vector<vector<int> > data;
loadData("kd.txt", data);
KdTree * kdtree = new KdTree(data, data[0].size());
kdtree->createTree();
kdtree->printKdTree();
cout << endl;
vector<int> vi(2);
cin >> vi[0] >> vi[1];
KdNode * r = kdtree->searchKdTree(vi);
cout << r->val << endl;
}
/*
測試數據
6 2
2 3
5 4
9 6
4 7
8 1
7 2
*/
如有錯誤,歡迎指出~