C++ 實現決策樹 ID3 算法

Reference

一步一步詳解ID3和C4.5的C++實現
決策樹之ID3算法

Notes

  • ID3只能處理離散性的屬性,屬性的每一種取值就剛好對應於樹上結點的一個分枝;
  • 樣本空間是字符串形式的,爲了處理方便,對每個屬性(包括 label)分別將它們的各種取值都映射爲一個整數(離散化,數字化),主要是我是用vectorvector的方式存樣本空間,這樣屬性的取值又剛好與下標對應;
  • 也是爲了方便,程序要求先輸入樣本個數、屬性個數,還有每個屬性的名字,在輸出決策樹時用;

Rendering

  • 輸入樣本空間

dt_input
其中,樣本空間上面的那一行是屬性名字,但沒有 label 的名字

  • 打印決策樹

dt_print

Sample Space

samspc
輸入的文本如下:

14 4
Outlook Temperature Humidity Windy
sunny hot high false no
sunny hot high true no
overcast hot high false yes
rain mild high false yes
rain cool normal false yes
rain cool normal true no
overcast cool normal true yes
sunny mild high false no
sunny cool normal false yes
rain mild normal false yes
sunny mild normal true yes
overcast mild high true yes
overcast hot normal false yes
rain mild high true no

Code

決策樹結點

struct dtNode
{
    int n_child;
    int attr_id;
    int label;
    vector<dtNode*> child;

    dtNode(): n_child(0), attr_id(-1), label(-1) {}
    ~dtNode() { child.clear(); }
};

決策樹類

typedef vector<vector<int> > SampleSpace;

class DecisionTree
{
    int n_dim;
    int n_sample;
    vector<string> attribute;
    SampleSpace samspc;
    vector<map<string,int> > dsc;
    vector<map<int,string> > rev;
    dtNode *root;

private:
    void _formatting(const vector<vector<string> > &in_sample_space);
    void _build(SampleSpace, dtNode *&);
    bool _same_class(const SampleSpace &);
    int _most_label(const SampleSpace &);
    int _max_gain_attr(const SampleSpace &);
    double _info_gain(const SampleSpace &, int attr);
    SampleSpace _drop_attr(SampleSpace, int attr);
    void _clear(dtNode *&);
    void _print_tree(dtNode *, int indent);

public:
    DecisionTree(int n_dimension, int n_sample);
    ~DecisionTree();
    void input();
    void build();
    void display();
};

完整代碼

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using std::cin;
using std::cout;
using std::endl;
using std::map;
using std::setw;
using std::string;
using std::vector;

/*------- 決策樹結點 -------*/

struct dtNode
{
    int n_child; // 子結點個數
    int attr_id; // 按哪個屬性來分類
    int label; // 分類標籤
    vector<dtNode*> child; // 子結點指針

    dtNode(): n_child(0), attr_id(-1), label(-1) {}
    ~dtNode() { child.clear(); }
};

/*------- 決策樹 -------*/

typedef vector<vector<int> > SampleSpace;

class DecisionTree
{
    // 樣本的維數(屬性數)
    int n_dim;
    // 樣本個數
    int n_sample;
    // 各屬性的名字 -> 用來打印
    vector<string> attribute;
    // 樣本空間
    SampleSpace samspc;
    // 離散化,把每一個屬性(包括標籤)的所有取值(字符串)都離散化成數字
    vector<map<string,int> > dsc;
    // 反離散化,根據離散化後的數字找回離散化前的字符串 -> 用來打印
    vector<map<int,string> > rev;
    // 決策樹樹根
    dtNode *root;

private:
    // 離散化,將輸入的字符串形式的樣本空間轉化成數字形式的樣本空間
    void _formatting(const vector<vector<string> > &in_sample_space);
    // 建樹
    void _build(SampleSpace, dtNode *&);
    // 判斷傳入的樣本空間是否所有樣本的 label 都一樣
    bool _same_class(const SampleSpace &);
    // 找到傳入的樣本空間裏佔最多數的 label
    int _most_label(const SampleSpace &);
    // 找到能獲得最大信息增益的(還沒被刪的)分類屬性
    int _max_gain_attr(const SampleSpace &);
    // 按某個屬性來算信息增益
    double _info_gain(const SampleSpace &, int attr);
    // 刪掉樣本空間的某個屬性(把那一列標爲負數)
    SampleSpace _drop_attr(SampleSpace, int attr);
    // 砍樹
    void _clear(dtNode *&);
    // 打印決策樹
    void _print_tree(dtNode *, int indent);

public:
    // constructor
    DecisionTree(int n_dimension, int n_sample);
    // destructor
    ~DecisionTree();
    // 輸入樣本空間(字符串形式)
    void input();
    // 建樹(驅動函數)
    void build();
    // 打印決策樹(驅動函數)
    void display();
};

/* 構造函數 */
DecisionTree::DecisionTree(int _d, int _s) :
    n_dim(_d), n_sample(_s), root(NULL)
{
    attribute = vector<string>(n_dim);
    samspc = SampleSpace(n_sample, vector<int>(n_dim + 1, 0));
    dsc = vector<map<string,int> >(n_dim + 1, map<string,int>());
    rev = vector<map<int,string> >(n_dim + 1, map<int,string>());
}

/* 析構函數 */
DecisionTree::~DecisionTree()
{
    attribute.clear();
    /* samspc 在 build() 結束時就被 clear() 了
     * samspc.clear();
     */
    dsc.clear();
    rev.clear();
    _clear(root);
}

/* 遞歸砍樹 */
void DecisionTree::_clear(dtNode *&rt)
{
    for(int i = 0; i < rt->child.size(); ++i)
        _clear(rt->child[i]);
    delete rt;
    rt = NULL;
}

/* 輸入樣本空間 */
void DecisionTree::input()
{
    cout << "\nInput the sample space"
        << "\nPlease ensure that one sample per line, and the label be the last one\n"
        << endl;
    // 讀入屬性的名字
    for(int i = 0; i < n_dim; ++i)
        cin >> attribute[i];

    vector<vector<string> > buf =
        vector<vector<string> >(n_sample, vector<string>(n_dim + 1));
    // 讀入屬性值和 label
    for(int i = 0; i < n_sample; ++i)
        for(int j = 0; j <= n_dim; ++j) // including the label
            cin >> buf[i][j];
    // 離散化樣本空間
    // 並存進 samspc
    _formatting(buf);
    // 字符串的樣本空間已沒用
    buf.clear();
}

/* 離散化樣本空間 */
void DecisionTree::_formatting(const vector<vector<string> > &buf)
{
    for(int d = 0; d <= n_dim; ++d) // including the label
    {
        int cnt = 0; // 離散化標號從 1 開始
        for(int i = 0; i < n_sample; ++i)
        {
            if(!dsc[d][buf[i][d]])
            {
                dsc[d][buf[i][d]] = ++cnt;
                rev[d][cnt] = buf[i][d];
            }
            samspc[i][d] = dsc[d][buf[i][d]];
        }
    }
}

/* 建樹 -> 調 _build() 來建 */
void DecisionTree::build()
{
    _build(samspc, root);
    samspc.clear();
}

/* 真·建樹 */
void DecisionTree::_build(SampleSpace sp, dtNode *&rt)
{
    rt = new dtNode();
    // 如果剩下的 label 全都相同
    // 直接讓這個結點成爲葉子
    // 就預測爲這個 label
    if(_same_class(sp))
    {
        rt->label = sp[0].back();
        return;
    }

    // 如果並不是所有 label 都相同
    // 但是已經沒有可用的屬性(全被刪掉)
    // 那這個結點也是葉子
    // 預測的 label 是佔數最多的那個 label
    int minus_cnt = 0;
    for(int i = 0; i < n_dim; ++i)
        minus_cnt += (sp[0][i] < 0 ? 1 : 0);
    // 因爲刪屬性只是把那一列屬性標成負數
    // 所以如果值爲負數的列數等於總的屬性數
    // 就意味着所有屬性都被刪了
    if(minus_cnt == n_dim) // no attribution left
    {
        rt->label = _most_label(sp);
        return;
    }

    // 選一個能獲得最大信息增益的屬性來分割
    rt->attr_id = _max_gain_attr(sp);
    // 理論上這個屬性有多少種取值可能
    // 這個結點就有多少個子結點
    rt->n_child = dsc[rt->attr_id].size();
    // 刪掉一列屬性後得到一個新樣本空間
    SampleSpace new_sp = _drop_attr(sp, rt->attr_id), sub;

    for(int i = 0; i < rt->n_child; ++i) // 枚舉這個屬性的所有可能取值
    {
        // 把樣本空間中這個屬性的取值是 i 的樣本找出來
        // 組成一個子樣本空間 sub
        for(int j = 0; j < sp.size(); ++j)
            if(sp[j][rt->attr_id] == i + 1)
                sub.push_back(new_sp[j]);
        // 如果樣本空間非空
        // 那對應的子結點才真的存在
        // 遞歸下去建樹
        if(!sub.empty())
        {
            rt->child.push_back(NULL);
            _build(sub, rt->child[rt->child.size() - 1]);
            sub.clear();
        }
    }
    // 重新數真正的子結點個數
    rt->n_child = rt->child.size();
    rt->child.resize(rt->n_child);
    // 樣本空間已沒用
    new_sp.clear();
    sp.clear();
}

/* 判斷樣本空間的樣本是否都是同一個 label */
bool DecisionTree::_same_class(const SampleSpace &sp)
{
    int lb = -1;
    for(int i = 0; i < sp.size(); ++i)
        if(lb == -1)
            lb = sp[i].back();
        else if(lb != sp[i].back())
            return false;
    return true;
}

/* 找到樣本空間裏佔數最多的那個 label */
int DecisionTree::_most_label(const SampleSpace &sp)
{
    int n_label = dsc[n_dim].size();
    int *cnt = new int[n_label + 1];
    for(int i = 0; i <= n_label; ++i)
        cnt[i] = 0;

    for(int i = 0; i < sp.size(); ++i)
        ++cnt[sp[i].back()];
    int res = 0;
    for(int i = 1; i <= n_label; ++i)
        if(cnt[i] > cnt[res])
            res = i;
    delete[] cnt;
    return res;
}

/* 找能獲得最大信息增益的那個屬性 */
int DecisionTree::_max_gain_attr(const SampleSpace &sp)
{
    int ans = 0;
    double big_gain = -100.0;
    for(int atr = 0; atr < sp[0].size() - 1; ++atr) // 最後一列是 label
    {
        if(sp[0][atr] < 0) // 已被刪的屬性不考慮
            continue;
        double g = _info_gain(sp, atr);
        if(g > big_gain)
        {
            big_gain = g;
            ans = atr;
        }
    }
    return ans;
}

/* 計算信息增益(其實只是信息熵,但效果相同) */
double DecisionTree::_info_gain(const SampleSpace &sp, int atr)
{
    double ans = 0.0;
    // label 取值種數‘屬性取值種數
    int n_label = dsc[n_dim].size(), n_attr_val = dsc[atr].size();
    // 統計數組 -> 統計各種 label 取值的出現次數
    int *cnt = new int[n_label + 1];

    for(int val = 1; val <= n_attr_val; ++val)
    {
        int tot = 0;
        double ent = 0.0;
        for(int j = 0; j <= n_label; ++j)
            cnt[j] = 0;

        for(int j = 0; j < sp.size(); ++j)
            if(sp[j][atr] == val)
            {
                ++cnt[sp[j].back()];
                ++tot;
            }

        for(int j = 1; j <= n_label; ++j)
        {
            double p = (double)cnt[j] / (double)tot;
            ent -= p * log2(p); // NOT +=
        }
        ans += (double)tot / (double)sp.size() * ent;
    }
    delete[] cnt;
    return ans;
}

/* 刪掉樣本空間的一列屬性 -> 把那一列標稱負數 */
SampleSpace DecisionTree::_drop_attr(SampleSpace sp, int atr)
{
    for(int i = 0; i < sp.size(); ++i)
        sp[i][atr] = /*-sp[i][atr]*/ -1;
    return sp;
}

/* 打印決策樹 */
void DecisionTree::display()
{
    cout << "--- Decision Tree ---" << endl;
    _print_tree(root, 0);
}

/* 真·打印決策樹 */
void DecisionTree::_print_tree(dtNode *rt, int ind)
{
    // 縮進
    for(int i = 0; i < ind; ++i)
        cout << ' ';

    // 如果是葉子
    if(!rt->n_child)
    {
        cout << rev[n_dim][rt->label] << endl;
        return;
    }
    // 打印屬性名
    cout << attribute[rt->attr_id] << endl;
    // 縮進加上屬性名的長度
    ind += attribute[rt->attr_id].length();

    int len = 0, idx = rt->attr_id;
    // 找最長的字符串的長度
    for(map<int,string>::iterator it = rev[idx].begin(); it != rev[idx].end(); ++it)
        if(len < it->second.length())
            len = it->second.length();

    for(int ch = 0; ch < rt->n_child; ++ch)
    {
        for(int j = 0; j < ind; ++j)
            cout << ' ';
        cout << setw(len) << rev[rt->attr_id][ch+1] << endl;
        _print_tree(rt->child[ch], ind + len + 1);
    }
}

/*------- Main Function -------*/

int main()
{
    cout << "--- Decision Tree ---\n";
    int sam, dim;
    cout << "\nNumber of samples in the sample space: ";
    cin >> sam;
    cout << "\nNumber of dimensions: ";
    cin >> dim;
    DecisionTree x_x(dim, sam);
    x_x.input();
    x_x.build();
    system("CLS");
    x_x.display();
    system("PAUSE");
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章