決策樹學習 之 ID3 C++STL代碼實現

轉自:http://blog.csdn.net/lee_cv/article/details/9197601 

決策樹學習 之 ID3 C++STL代碼實現

很久沒寫含這麼多stl的程序了,很故意的用set,map,vector,熟手一下。

也記錄一下吧,雖然寫得比較渣。


三個文件:

測試數據:data.txt

[plain] view plaincopy
  1. D1    Sunny        Hot    High        Weak    No  
  2. D2    Sunny        Hot    High        Strong    No  
  3. D3    Overcast    Hot    High        Weak    Yes  
  4. D4    Rain        Mild    High        Weak    Yes  
  5. D5    Rain        Cool    Normal        Weak    Yes  
  6. D6    Rain        Cool    Normal        Strong    No  
  7. D7    Overcast    Cool    Normal        Strong    Yes  
  8. D8    Sunny        Mild    High        Weak    No  
  9. D9    Sunny        Cool    Normal        Weak    Yes  
  10. D10    Rain        Mild    Normal        Weak    Yes  
  11. D11    Sunny        Mild    Normal        Strong    Yes  
  12. D12    Overcast    Mild    High        Strong    Yes  
  13. D13    Overcast    Hot    Normal        Weak    Yes  
  14. D14    Rain        Mild    High        Strong    No  


程序頭文件:id3.h
  1. #ifndef ID3_H  
  2. #define ID3_H  
  3. #include<fstream>  
  4. #include<iostream>  
  5. #include<vector>  
  6. #include<map>  
  7. #include<set>  
  8. #include<cmath>  
  9. using namespace std;  
  10. const int DataRow=14;  
  11. const int DataColumn=6;  
  12. struct Node  
  13. {  
  14.     double value;//代表此時yes的概率。  
  15.     int attrid;  
  16.     Node * parentNode;  
  17.     vector<Node*> childNode;  
  18. };  
  19. #endif  

程序源文件id3.cpp

  1. #include "id3.h"  
  2.   
  3. string DataTable[DataRow][DataColumn];  
  4. map<string,int> str2int;  
  5. set<int> S;  
  6. set<int> Attributes;  
  7. string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};  
  8. string attrValue[DataColumn][DataRow]=  
  9. {  
  10.     {},//D1,D2這個屬性不需要  
  11.     {"Sunny","Overcast","Rain"},  
  12.     {"Hot","Mild","Cool"},  
  13.     {"High","Normal"},  
  14.     {"Weak","Strong"},  
  15.     {"No","Yes"}  
  16. };  
  17. int attrCount[DataColumn]={14,3,3,2,2,2};  
  18. double lg2(double n)  
  19. {  
  20.     return log(n)/log(2);  
  21. }  
  22. void Init()  
  23. {  
  24.     ifstream fin("data.txt");  
  25.     for(int i=0;i<14;i++)  
  26.     {  
  27.       for(int j=0;j<6;j++)  
  28.       {  
  29.           fin>>DataTable[i][j];  
  30.       }  
  31.     }  
  32.     fin.close();  
  33.     for(int i=1;i<=5;i++)  
  34.     {  
  35.         str2int[attrName[i]]=i;  
  36.         for(int j=0;j<attrCount[i];j++)  
  37.         {  
  38.             str2int[attrValue[i][j]]=j;  
  39.         }  
  40.     }  
  41.     for(int i=0;i<DataRow;i++)  
  42.       S.insert(i);  
  43.     for(int i=1;i<=4;i++)  
  44.       Attributes.insert(i);  
  45. }  
  46.   
  47. double Entropy(const set<int> &s)  
  48. {  
  49.     double yes=0,no=0,sum=s.size(),ans=0;  
  50.     for(set<int>::iterator it=s.begin();it!=s.end();it++)  
  51.     {  
  52.         string s=DataTable[*it][str2int["PlayTennis"]];  
  53.         if(s=="Yes")  
  54.           yes++;  
  55.         else  
  56.           no++;  
  57.     }  
  58.     if(no==0||yes==0)  
  59.       return ans=0;  
  60.     ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);  
  61.     return ans;  
  62. }  
  63. double Gain(const set<int> & example,int attrid)  
  64. {  
  65.     int attrcount=attrCount[attrid];  
  66.     double ans=Entropy(example);  
  67.     double sum=example.size();  
  68.     set<int> * pset=new set<int>[attrcount];  
  69.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  70.     {  
  71.         pset[str2int[DataTable[*it][attrid]]].insert(*it);  
  72.     }  
  73.     for(int i=0;i<attrcount;i++)  
  74.     {  
  75.         ans-=pset[i].size()/sum*Entropy(pset[i]);  
  76.     }  
  77.     return ans;  
  78. }  
  79. int FindBestAttribute(const set<int> & example,const set<int> & attr)  
  80. {  
  81.     double mx=0;  
  82.     int k=-1;  
  83.     for(set<int>::iterator i=attr.begin();i!=attr.end();i++)  
  84.     {  
  85.         double ret=Gain(example,*i);  
  86.         if(ret>mx)  
  87.         {  
  88.             mx=ret;  
  89.             k=*i;  
  90.         }  
  91.     }  
  92.     if(k==-1)  
  93.       cout<<"FindBestAttribute error!"<<endl;  
  94.     return k;  
  95. }  
  96. Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)  
  97. {  
  98.     Node *now=new Node;//創建樹節點。  
  99.     now->parentNode=parent;  
  100.     if(attributes.empty())//如果此時屬性列表已用完,即爲空,則返回。  
  101.       return now;  
  102.   
  103.     /* 
  104.      * 統計一下example,如果都爲正或者都爲負則表示已經抵達決策樹的葉子節點 
  105.      * 葉子節點的特徵是有childNode爲空。 
  106.      */  
  107.     int yes=0,no=0,sum=example.size();  
  108.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  109.     {  
  110.         string s=DataTable[*it][str2int["PlayTennis"]];  
  111.         if(s=="Yes")  
  112.           yes++;  
  113.         else  
  114.           no++;  
  115.     }  
  116.     if(yes==sum||yes==0)  
  117.     {  
  118.         now->value=yes/sum;  
  119.         return now;  
  120.     }  
  121.       
  122.   
  123.     /*找到最高信息增益的屬性並將該屬性從attributes集合中刪除*/  
  124.     int bestattrid=FindBestAttribute(example,attributes);  
  125.     now->attrid=bestattrid;  
  126.     attributes.erase(attributes.find(bestattrid));  
  127.       
  128.     /*將exmple根據最佳屬性的不同屬性值分成幾個分支,每個分支有即一個子樹*/  
  129.     vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);  
  130.     for(set<int>::iterator i=example.begin();i!=example.end();i++)  
  131.     {  
  132.         int id=str2int[DataTable[*i][bestattrid]];  
  133.         child[id].insert(*i);  
  134.     }  
  135.     for(int i=0;i<child.size();i++)  
  136.     {  
  137.         Node * ret=Id3_solution(child[i],attributes,now);  
  138.         now->childNode.push_back(ret);  
  139.     }  
  140.     return now;  
  141. }  
  142.   
  143. int main()  
  144. {  
  145.     Init();  
  146.     Node * Root=Id3_solution(S,Attributes,NULL);  
  147.     return 0;  
  148. }  
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章