KD樹的C++實現

理論介紹

kd樹(K-dimension tree)是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構。kd樹是是一種二叉樹,表示對k維空間的一個劃分,構造kd樹相當於不斷地用垂直於座標軸的超平面將K維空間切分,構成一系列的K維超矩形區域。kd樹的每個結點對應於一個k維超矩形區域。利用kd樹可以省去對大部分數據點的搜索,從而減少搜索的計算量。

  • kd樹的算法步驟
    這裏寫圖片描述
     
  • kd樹可以較大的減少搜索空間,提高搜索效率,在knn,光線追蹤這些需要大量搜索的算法中具有重要的應用前景。

算法實現

  • 使用C++ 對kd樹進行了實現,沒有使用模板,如有錯誤歡迎指出~
#pragma once
#include <iostream>
#include <vector>
#include <algorithm>
#include <fstream>
using namespace std;

int dimUsed;

// kd樹的節點定義
struct KdNode
{
    KdNode * parent;
    KdNode * leftChild;
    KdNode * rightChild;
    vector<int> val;    //存儲的數據
    int axis;       // 使用的軸
    KdNode(vector<int> data, int ax)
    {
        val = data;
        axis = ax;
        parent = NULL;
        leftChild = NULL;
        rightChild = NULL;
    }
};

// 用於以第n位進行比較,方便重載
bool cmp(vector<int> a, vector<int> b)
{
    if (a[dimUsed] < b[dimUsed])
        return true;
    return false;
}

ostream & operator<<(ostream & os, vector<int> vi)
{
    os << "(";
    for (int i = 0; i < vi.size(); i++)
        cout << vi[i] << ",";
    os << ")";
    return os;
}

// 導入數據,第一行是數據數量和維度,後面跟着num行dim維度的數據,每個一行
bool loadData(string filename, vector<vector<int> > &data)
{
    ifstream infs(filename);
    if (infs.is_open())
    {
        int num,dim;
        infs >> num>>dim;
        vector<int> d(dim);
        for (int i = 0; i < num; i++)
        {
            for (int j = 0; j < dim; j++)
                infs >> d[j];
            data.push_back(d);
        }
        return true;
    }
    return false;
}

// 計算N維向量距離
int disVector(vector<int> a, vector<int> b)
{
    int sum = 0;
    for (int i = 0; i < a.size(); i++)
        sum += (a[i] - b[i])*(a[i] - b[i]);
    return sum;
}

// kd 樹的類
class KdTree
{
private:
    int dimension;
    vector<vector<int> > data;
    KdNode * root;
public:
    KdTree(vector<vector<int> > d, int dim)
    {
        dimension = dim;
        data = d;
    }
    void createTree()
    {
        // 遞歸建樹
        root = createTreeNode(0, data.size()-1,0);
    }

    // create Kd Tree struct
    KdNode * createTreeNode(int left, int right,int dim)
    {
        if (right < left)
            return NULL;
        dimUsed = dim;
        // 按照k維進行排序
        sort(data.begin() + left, data.begin() + right+1, cmp);
        int mid = (left + right+1) / 2;
        KdNode * r = new KdNode(data[mid], dim);
        r->leftChild = createTreeNode(left, mid - 1, (dim + 1) % dimension);
        if (r->leftChild != NULL)
            r->leftChild->parent = r;
        r->rightChild = createTreeNode(mid + 1, right, (dim + 1) % dimension);
        if (r->rightChild != NULL)
            r->rightChild->parent = r;
        return r;
    }

    void printKdTree()
    {
        printKdTreeNode(root);
    }

    void printKdTreeNode(KdNode * r)
    {
        if (r == NULL)
            return;
        printKdTreeNode(r->leftChild);
        cout << r->val << "\t";
        printKdTreeNode(r->rightChild);
    }

    // 查找kd樹
    KdNode * searchKdTree(vector<int> d)
    {
        int dim = 0,minDis = 10000000;
        KdNode * r = root;
        KdNode * tmp;
        while (r != NULL)
        {
            tmp = r;
            if (d[dim] < r->val[dim])                           
                r = r->leftChild;   
            else
                r = r->rightChild;
            dim = (dim + 1) % dimension;
        }
        // 找到屬於的那個節點
        minDis = min(disVector(d, tmp->val),minDis);
        KdNode * nearNode = tmp;
        cout << endl<<"nearest node: "<<tmp->val << endl;
        // 退回到根節點
        while (tmp->parent != NULL)
        {
            tmp = tmp->parent;
            // 判斷父節點是否更近,如果近,更新最近節點
            if (disVector(tmp->val, d) < minDis)
            {
                nearNode = tmp;
                minDis = disVector(tmp->val, d);
            }
            cout << "now parent node: " << tmp->val << endl;
            KdNode * son;
            // 判斷當前軸與點的距離,如果小於minDis,則進行到另一半進行查找
            if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
            {
                if (tmp->val[tmp->axis] > d[tmp->axis])
                    son = tmp->rightChild;
                else
                    son = tmp->leftChild;
                searchKdTreeNode(d, minDis, nearNode, son);
            }
        }   
        // 對根節點的另外半邊子樹進行搜索
        /*if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
        {
            if (tmp->val[tmp->axis] > d[tmp->axis])
                tmp = tmp->rightChild;
            else
                tmp = tmp->leftChild;
            searchKdTreeNode(d, minDis, nearNode, tmp);
        }*/
        return nearNode;
    }

    // 查找當前節點下的最近點
    void searchKdTreeNode(vector<int> d,int &minDis,KdNode * &nearNode,KdNode * tmp)
    {
        // 遞歸終止
        if (tmp == NULL)
            return;
        cout << "now node: " << tmp->val << endl;
        // 判斷當前節點是否小於
        if (disVector(tmp->val, d) < minDis)
        {
            minDis = disVector(tmp->val, d);
            nearNode = tmp;
        }
        // 如果軸與節點的距離小於minDis,則兩個半邊都需要搜索,否則只需要搜索一個半邊
        if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
        {
            searchKdTreeNode(d, minDis, nearNode, tmp->leftChild);
            searchKdTreeNode(d, minDis, nearNode, tmp->rightChild);
        }
        else
        {
        // 選擇搜索的一個半邊
            if (tmp->val[tmp->axis] > d[tmp->axis])
                searchKdTreeNode(d, minDis, nearNode, tmp->leftChild);
            else
                searchKdTreeNode(d, minDis, nearNode, tmp->rightChild);
        }

    }

};

// 測試kd樹
void testKdTree()
{
    vector<vector<int> > data;
    loadData("kd.txt", data);
    KdTree * kdtree = new KdTree(data, data[0].size());
    kdtree->createTree();
    kdtree->printKdTree();
    cout << endl;
    vector<int> vi(2);
    cin >> vi[0] >> vi[1];
    KdNode * r = kdtree->searchKdTree(vi);
    cout << r->val << endl;
}

/*
測試數據
6 2
2 3
5 4
9 6
4 7
8 1
7 2
*/

github地址

如有錯誤,歡迎指出~

發佈了36 篇原創文章 · 獲贊 65 · 訪問量 15萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章