《機器學習》第三章決策樹學習 ID3算法 c++實現代碼

很久沒寫含這麼多stl的程序了,很故意的用set,map,vector,熟手一下。

也記錄一下吧,雖然寫得比較渣。


三個文件:

測試數據:data.txt

D1    Sunny        Hot    High        Weak    No
D2    Sunny        Hot    High        Strong    No
D3    Overcast    Hot    High        Weak    Yes
D4    Rain        Mild    High        Weak    Yes
D5    Rain        Cool    Normal        Weak    Yes
D6    Rain        Cool    Normal        Strong    No
D7    Overcast    Cool    Normal        Strong    Yes
D8    Sunny        Mild    High        Weak    No
D9    Sunny        Cool    Normal        Weak    Yes
D10    Rain        Mild    Normal        Weak    Yes
D11    Sunny        Mild    Normal        Strong    Yes
D12    Overcast    Mild    High        Strong    Yes
D13    Overcast    Hot    Normal        Weak    Yes
D14    Rain        Mild    High        Strong    No


程序頭文件:id3.h
#ifndef ID3_H
#define ID3_H
#include<fstream>
#include<iostream>
#include<vector>
#include<map>
#include<set>
#include<cmath>
using namespace std;
const int DataRow=14;
const int DataColumn=6;
struct Node
{
	double value;//代表此時yes的概率。
	int attrid;
	Node * parentNode;
	vector<Node*> childNode;
};
#endif

程序源文件id3.cpp

#include "id3.h"

string DataTable[DataRow][DataColumn];
map<string,int> str2int;
set<int> S;
set<int> Attributes;
string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};
string attrValue[DataColumn][DataRow]=
{
	{},//D1,D2這個屬性不需要
	{"Sunny","Overcast","Rain"},
	{"Hot","Mild","Cool"},
	{"High","Normal"},
	{"Weak","Strong"},
	{"No","Yes"}
};
int attrCount[DataColumn]={14,3,3,2,2,2};
double lg2(double n)
{
	return log(n)/log(2);
}
void Init()
{
	ifstream fin("data.txt");
	for(int i=0;i<14;i++)
	{
	  for(int j=0;j<6;j++)
	  {
		  fin>>DataTable[i][j];
	  }
	}
	fin.close();
	for(int i=1;i<=5;i++)
	{
		str2int[attrName[i]]=i;
		for(int j=0;j<attrCount[i];j++)
		{
			str2int[attrValue[i][j]]=j;
		}
	}
	for(int i=0;i<DataRow;i++)
	  S.insert(i);
	for(int i=1;i<=4;i++)
	  Attributes.insert(i);
}

double Entropy(const set<int> &s)
{
	double yes=0,no=0,sum=s.size(),ans=0;
	for(set<int>::iterator it=s.begin();it!=s.end();it++)
	{
		string s=DataTable[*it][str2int["PlayTennis"]];
		if(s=="Yes")
		  yes++;
		else
		  no++;
	}
	if(no==0||yes==0)
	  return ans=0;
	ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);
	return ans;
}
double Gain(const set<int> & example,int attrid)
{
	int attrcount=attrCount[attrid];
	double ans=Entropy(example);
	double sum=example.size();
	set<int> * pset=new set<int>[attrcount];
	for(set<int>::iterator it=example.begin();it!=example.end();it++)
	{
		pset[str2int[DataTable[*it][attrid]]].insert(*it);
	}
	for(int i=0;i<attrcount;i++)
	{
		ans-=pset[i].size()/sum*Entropy(pset[i]);
	}
	return ans;
}
int FindBestAttribute(const set<int> & example,const set<int> & attr)
{
	double mx=0;
	int k=-1;
	for(set<int>::iterator i=attr.begin();i!=attr.end();i++)
	{
		double ret=Gain(example,*i);
		if(ret>mx)
		{
			mx=ret;
			k=*i;
		}
	}
	if(k==-1)
	  cout<<"FindBestAttribute error!"<<endl;
	return k;
}
Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)
{
	Node *now=new Node;//創建樹節點。
	now->parentNode=parent;
	if(attributes.empty())//如果此時屬性列表已用完,即爲空,則返回。
	  return now;

	/*
	 * 統計一下example,如果都爲正或者都爲負則表示已經抵達決策樹的葉子節點
	 * 葉子節點的特徵是有childNode爲空。
	 */
	int yes=0,no=0,sum=example.size();
	for(set<int>::iterator it=example.begin();it!=example.end();it++)
	{
		string s=DataTable[*it][str2int["PlayTennis"]];
		if(s=="Yes")
		  yes++;
		else
		  no++;
	}
	if(yes==sum||yes==0)
	{
		now->value=yes/sum;
		return now;
	}
	

	/*找到最高信息增益的屬性並將該屬性從attributes集合中刪除*/
	int bestattrid=FindBestAttribute(example,attributes);
	now->attrid=bestattrid;
	attributes.erase(attributes.find(bestattrid));
	
	/*將exmple根據最佳屬性的不同屬性值分成幾個分支,每個分支有即一個子樹*/
	vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);
	for(set<int>::iterator i=example.begin();i!=example.end();i++)
	{
		int id=str2int[DataTable[*i][bestattrid]];
		child[id].insert(*i);
	}
	for(int i=0;i<child.size();i++)
	{
		Node * ret=Id3_solution(child[i],attributes,now);
		now->childNode.push_back(ret);
	}
	return now;
}

int main()
{
	Init();
	Node * Root=Id3_solution(S,Attributes,NULL);
	return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章