很久沒寫含這麼多stl的程序了,很故意的用set,map,vector,熟手一下。
也記錄一下吧,雖然寫得比較渣。
三個文件:
測試數據:data.txt
D1 Sunny Hot High Weak No
D2 Sunny Hot High Strong No
D3 Overcast Hot High Weak Yes
D4 Rain Mild High Weak Yes
D5 Rain Cool Normal Weak Yes
D6 Rain Cool Normal Strong No
D7 Overcast Cool Normal Strong Yes
D8 Sunny Mild High Weak No
D9 Sunny Cool Normal Weak Yes
D10 Rain Mild Normal Weak Yes
D11 Sunny Mild Normal Strong Yes
D12 Overcast Mild High Strong Yes
D13 Overcast Hot Normal Weak Yes
D14 Rain Mild High Strong No
#ifndef ID3_H
#define ID3_H
#include<fstream>
#include<iostream>
#include<vector>
#include<map>
#include<set>
#include<cmath>
using namespace std;
const int DataRow=14;
const int DataColumn=6;
struct Node
{
double value;//代表此時yes的概率。
int attrid;
Node * parentNode;
vector<Node*> childNode;
};
#endif
程序源文件id3.cpp
#include "id3.h"
string DataTable[DataRow][DataColumn];
map<string,int> str2int;
set<int> S;
set<int> Attributes;
string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};
string attrValue[DataColumn][DataRow]=
{
{},//D1,D2這個屬性不需要
{"Sunny","Overcast","Rain"},
{"Hot","Mild","Cool"},
{"High","Normal"},
{"Weak","Strong"},
{"No","Yes"}
};
int attrCount[DataColumn]={14,3,3,2,2,2};
double lg2(double n)
{
return log(n)/log(2);
}
void Init()
{
ifstream fin("data.txt");
for(int i=0;i<14;i++)
{
for(int j=0;j<6;j++)
{
fin>>DataTable[i][j];
}
}
fin.close();
for(int i=1;i<=5;i++)
{
str2int[attrName[i]]=i;
for(int j=0;j<attrCount[i];j++)
{
str2int[attrValue[i][j]]=j;
}
}
for(int i=0;i<DataRow;i++)
S.insert(i);
for(int i=1;i<=4;i++)
Attributes.insert(i);
}
double Entropy(const set<int> &s)
{
double yes=0,no=0,sum=s.size(),ans=0;
for(set<int>::iterator it=s.begin();it!=s.end();it++)
{
string s=DataTable[*it][str2int["PlayTennis"]];
if(s=="Yes")
yes++;
else
no++;
}
if(no==0||yes==0)
return ans=0;
ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);
return ans;
}
double Gain(const set<int> & example,int attrid)
{
int attrcount=attrCount[attrid];
double ans=Entropy(example);
double sum=example.size();
set<int> * pset=new set<int>[attrcount];
for(set<int>::iterator it=example.begin();it!=example.end();it++)
{
pset[str2int[DataTable[*it][attrid]]].insert(*it);
}
for(int i=0;i<attrcount;i++)
{
ans-=pset[i].size()/sum*Entropy(pset[i]);
}
return ans;
}
int FindBestAttribute(const set<int> & example,const set<int> & attr)
{
double mx=0;
int k=-1;
for(set<int>::iterator i=attr.begin();i!=attr.end();i++)
{
double ret=Gain(example,*i);
if(ret>mx)
{
mx=ret;
k=*i;
}
}
if(k==-1)
cout<<"FindBestAttribute error!"<<endl;
return k;
}
Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)
{
Node *now=new Node;//創建樹節點。
now->parentNode=parent;
if(attributes.empty())//如果此時屬性列表已用完,即爲空,則返回。
return now;
/*
* 統計一下example,如果都爲正或者都爲負則表示已經抵達決策樹的葉子節點
* 葉子節點的特徵是有childNode爲空。
*/
int yes=0,no=0,sum=example.size();
for(set<int>::iterator it=example.begin();it!=example.end();it++)
{
string s=DataTable[*it][str2int["PlayTennis"]];
if(s=="Yes")
yes++;
else
no++;
}
if(yes==sum||yes==0)
{
now->value=yes/sum;
return now;
}
/*找到最高信息增益的屬性並將該屬性從attributes集合中刪除*/
int bestattrid=FindBestAttribute(example,attributes);
now->attrid=bestattrid;
attributes.erase(attributes.find(bestattrid));
/*將exmple根據最佳屬性的不同屬性值分成幾個分支,每個分支有即一個子樹*/
vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);
for(set<int>::iterator i=example.begin();i!=example.end();i++)
{
int id=str2int[DataTable[*i][bestattrid]];
child[id].insert(*i);
}
for(int i=0;i<child.size();i++)
{
Node * ret=Id3_solution(child[i],attributes,now);
now->childNode.push_back(ret);
}
return now;
}
int main()
{
Init();
Node * Root=Id3_solution(S,Attributes,NULL);
return 0;
}