分類迴歸樹(CART,Classification And Regression Tree)也屬於一種決策樹,上回文我們介紹了基於ID3算法的決策樹。作爲上篇,這裏只介紹CART是怎樣用於分類的。
分類迴歸樹是一棵二叉樹,且每個非葉子節點都有兩個孩子,所以對於第一棵子樹其葉子節點數比非葉子節點數多1。
表1
名稱 | 體溫 | 表面覆蓋 | 胎生 | 產蛋 | 能飛 | 水生 | 有腿 | 冬眠 | 類標記 |
人 | 恆溫 | 毛髮 | 是 | 否 | 否 | 否 | 是 | 否 | 哺乳類 |
巨蟒 | 冷血 | 鱗片 | 否 | 是 | 否 | 否 | 否 | 是 | 爬行類 |
鮭魚 | 冷血 | 鱗片 | 否 | 是 | 否 | 是 | 否 | 否 | 魚類 |
鯨 | 恆溫 | 毛髮 | 是 | 否 | 否 | 是 | 否 | 否 | 哺乳類 |
蛙 | 冷血 | 無 | 否 | 是 | 否 | 有時 | 是 | 是 | 兩棲類 |
巨蜥 | 冷血 | 鱗片 | 否 | 是 | 否 | 否 | 是 | 否 | 爬行類 |
蝙蝠 | 恆溫 | 毛髮 | 是 | 否 | 是 | 否 | 是 | 否 | 哺乳類 |
貓 | 恆溫 | 皮 | 是 | 否 | 否 | 否 | 是 | 否 | 哺乳類 |
豹紋鯊 | 冷血 | 鱗片 | 是 | 否 | 否 | 是 | 否 | 否 | 魚類 |
海龜 | 冷血 | 鱗片 | 否 | 是 | 否 | 有時 | 是 | 否 | 爬行類 |
豪豬 | 恆溫 | 剛毛 | 是 | 否 | 否 | 否 | 是 | 是 | 哺乳類 |
鰻 | 冷血 | 鱗片 | 否 | 是 | 否 | 是 | 否 | 否 | 魚類 |
蠑螈 | 冷血 | 無 | 否 | 是 | 否 | 有時 | 是 | 是 | 兩棲類 |
上例是屬性有8個,每個屬性又有多少離散的值可取。在決策樹的每一個節點上我們可以按任一個屬性的任一個值進行劃分。比如最開始我們按:
1)表面覆蓋爲毛髮和非毛髮
2)表面覆蓋爲鱗片和非鱗片
3)體溫爲恆溫和非恆溫
等等產生當前節點的左右兩個孩子。按哪種劃分最好呢?有3個標準可以用來衡量劃分的好壞:GINI指數、雙化指數、有序雙化指數。下面我們只講GINI指數。
GINI指數
總體內包含的類別越雜亂,GINI指數就越大(跟熵的概念很相似)。比如體溫爲恆溫時包含哺乳類5個、鳥類2個,則:
體溫爲非恆溫時包含爬行類3個、魚類3個、兩棲類2個,則
所以如果按照“體溫爲恆溫和非恆溫”進行劃分的話,我們得到GINI的增益(類比信息增益):
最好的劃分就是使得GINI_Gain最小的劃分。
終止條件
一個節點產生左右孩子後,遞歸地對左右孩子進行劃分即可產生分類迴歸樹。這裏的終止條件是什麼?什麼時候節點就可以停止分裂了?直觀的情況,當節點包含的數據記錄都屬於同一個類別時就可以終止分裂了。這只是一個特例,更一般的情況我們計算χ2值來判斷分類條件和類別的相關程度,當χ2很小時說明分類條件和類別是獨立的,即按照該分類條件進行分類是沒有道理的,此時節點停止分裂。注意這裏的“分類條件”是指按照GINI_Gain最小原則得到的“分類條件”。
假如在構造分類迴歸樹的第一步我們得到的“分類條件”是:體溫爲恆溫和非恆溫。此時:
哺乳類 | 爬行類 | 魚類 | 鳥類 | 兩棲類 | |
恆溫 | 5 | 0 | 0 | 2 | 0 |
非恆溫 | 0 | 3 | 3 | 0 | 2 |
我在《獨立性檢驗》中講述了χ2的計算方法。當選定置信水平後查表可得“體溫”與動物類別是否相互獨立。
還有一種方式就是,如果某一分支覆蓋的樣本的個數如果小於一個閾值,那麼也可產生葉子節點,從而終止Tree-Growth。
剪枝
當分類迴歸樹劃分得太細時,會對噪聲數據產生過擬合作用。因此我們要通過剪枝來解決。剪枝又分爲前剪枝和後剪枝:前剪枝是指在構造樹的過程中就知道哪些節點可以剪掉,於是乾脆不對這些節點進行分裂,在N皇后問題和揹包問題中用的都是前剪枝,上面的χ2方法也可以認爲是一種前剪枝;後剪枝是指構造出完整的決策樹之後再來考查哪些子樹可以剪掉。
在分類迴歸樹中可以使用的後剪枝方法有多種,比如:代價複雜性剪枝、最小誤差剪枝、悲觀誤差剪枝等等。這裏我們只介紹代價複雜性剪枝法。
對於分類迴歸樹中的每一個非葉子節點計算它的表面誤差率增益值α。
是子樹中包含的葉子節點個數;
是節點t的誤差代價,如果該節點被剪枝;
r(t)是節點t的誤差率;
p(t)是節點t上的數據佔所有數據的比例。
是子樹Tt的誤差代價,如果該節點不被剪枝。它等於子樹Tt上所有葉子節點的誤差代價之和。
比如有個非葉子節點t4如圖所示:
已知所有的數據總共有60條,則節點t4的節點誤差代價爲:
子樹誤差代價爲:
以t4爲根節點的子樹上葉子節點有3個,最終:
找到α值最小的非葉子節點,令其左右孩子爲NULL。當多個非葉子節點的α值同時達到最小時,取最大的進行剪枝。
源代碼。拿表1作爲訓練數據,得到剪枝前和剪枝後的兩棵分類迴歸樹,再對錶1中的數據進行分類測試。
#include<iostream>
#include<fstream>
#include<sstream>
#include<string>
#include<map>
#include<list>
#include<set>
#include<queue>
#include<utility>
#include<vector>
#include<cmath>
using namespace std;
//置信水平取0.95時的卡方表
const double CHI[18]={0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962};
/*根據多維數組計算卡方值*/
template<typename Comparable>
double cal_chi(Comparable **arr,int row,int col){
vector<Comparable> rowsum(row);
vector<Comparable> colsum(col);
Comparable totalsum=static_cast<Comparable>(0);
//cout<<"observation"<<endl;
for(int i=0;i<row;++i){
for(int j=0;j<col;++j){
//cout<<arr[i][j]<<"\t";
totalsum+=arr[i][j];
rowsum[i]+=arr[i][j];
colsum[j]+=arr[i][j];
}
//cout<<endl;
}
double rect=0.0;
//cout<<"exception"<<endl;
for(int i=0;i<row;++i){
for(int j=0;j<col;++j){
double excep=1.0*rowsum[i]*colsum[j]/totalsum;
//cout<<excep<<"\t";
if(excep!=0)
rect+=pow(arr[i][j]-excep,2.0)/excep;
}
//cout<<endl;
}
return rect;
}
class MyTriple{
public:
double first;
int second;
int third;
MyTriple(){
first=0.0;
second=0;
third=0;
}
MyTriple(double f,int s,int t):first(f),second(s),third(t){}
bool operator< (const MyTriple &obj) const{
int cmp=this->first-obj.first;
if(cmp>0)
return false;
else if(cmp<0)
return true;
else{
cmp=obj.second-this->second;
if(cmp<0)
return true;
else
return false;
}
}
};
typedef map<string,int> MAP_REST_COUNT;
typedef map<string,MAP_REST_COUNT> MAP_ATTR_REST;
typedef vector<MAP_ATTR_REST> VEC_STATI;
const int ATTR_NUM=8; //自變量的維度
vector<string> X(ATTR_NUM);
int rest_number; //因變量的種類數,即類別數
vector<pair<string,int> > classes; //把類別、對應的記錄數存放在一個數組中
int total_record_number; //總的記錄數
vector<vector<string> > inputData; //原始輸入數據
class node{
public:
node* parent; //父節點
node* leftchild; //左孩子節點
node* rightchild; //右孩子節點
string cond; //分枝條件
string decision; //在該節點上作出的類別判定
double precision; //判定的正確率
int record_number; //該節點上涵蓋的記錄個數
int size; //子樹包含的葉子節點的數目
int index; //層次遍歷樹,給節點標上序號
double alpha; //表面誤差率的增加量
node(){
parent=NULL;
leftchild=NULL;
rightchild=NULL;
precision=0.0;
record_number=0;
size=1;
index=0;
alpha=1.0;
}
node(node* p){
parent=p;
leftchild=NULL;
rightchild=NULL;
precision=0.0;
record_number=0;
size=1;
index=0;
alpha=1.0;
}
node(node* p,string c,string d):cond(c),decision(d){
parent=p;
leftchild=NULL;
rightchild=NULL;
precision=0.0;
record_number=0;
size=1;
index=0;
alpha=1.0;
}
void printInfo(){
cout<<"index:"<<index<<"\tdecisoin:"<<decision<<"\tprecision:"<<precision<<"\tcondition:"<<cond<<"\tsize:"<<size;
if(parent!=NULL)
cout<<"\tparent index:"<<parent->index;
if(leftchild!=NULL)
cout<<"\tleftchild:"<<leftchild->index<<"\trightchild:"<<rightchild->index;
cout<<endl;
}
void printTree(){
printInfo();
if(leftchild!=NULL)
leftchild->printTree();
if(rightchild!=NULL)
rightchild->printTree();
}
};
int readInput(string filename){
ifstream ifs(filename.c_str());
if(!ifs){
cerr<<"open inputfile failed!"<<endl;
return -1;
}
map<string,int> catg;
string line;
getline(ifs,line);
string item;
istringstream strstm(line);
strstm>>item;
for(int i=0;i<X.size();++i){
strstm>>item;
X[i]=item;
}
while(getline(ifs,line)){
vector<string> conts(ATTR_NUM+2);
istringstream strstm(line);
//strstm.str(line);
for(int i=0;i<conts.size();++i){
strstm>>item;
conts[i]=item;
if(i==conts.size()-1)
catg[item]++;
}
inputData.push_back(conts);
}
total_record_number=inputData.size();
ifs.close();
map<string,int>::const_iterator itr=catg.begin();
while(itr!=catg.end()){
classes.push_back(make_pair(itr->first,itr->second));
itr++;
}
rest_number=classes.size();
return 0;
}
/*根據inputData作出一個統計stati*/
void statistic(vector<vector<string> > &inputData,VEC_STATI &stati){
for(int i=1;i<ATTR_NUM+1;++i){
MAP_ATTR_REST attr_rest;
for(int j=0;j<inputData.size();++j){
string attr_value=inputData[j][i];
string rest=inputData[j][ATTR_NUM+1];
MAP_ATTR_REST::iterator itr=attr_rest.find(attr_value);
if(itr==attr_rest.end()){
MAP_REST_COUNT rest_count;
rest_count[rest]=1;
attr_rest[attr_value]=rest_count;
}
else{
MAP_REST_COUNT::iterator iter=(itr->second).find(rest);
if(iter==(itr->second).end()){
(itr->second).insert(make_pair(rest,1));
}
else{
iter->second+=1;
}
}
}
stati.push_back(attr_rest);
}
}
/*依據某條件作出分枝時,inputData被分成兩部分*/
void splitInput(vector<vector<string> > &inputData,int fitIndex,string cond,vector<vector<string> > &LinputData,vector<vector<string> > &RinputData){
for(int i=0;i<inputData.size();++i){
if(inputData[i][fitIndex+1]==cond)
LinputData.push_back(inputData[i]);
else
RinputData.push_back(inputData[i]);
}
}
void printStati(VEC_STATI &stati){
for(int i=0;i<stati.size();i++){
MAP_ATTR_REST::const_iterator itr=stati[i].begin();
while(itr!=stati[i].end()){
cout<<itr->first;
MAP_REST_COUNT::const_iterator iter=(itr->second).begin();
while(iter!=(itr->second).end()){
cout<<"\t"<<iter->first<<"\t"<<iter->second;
iter++;
}
itr++;
cout<<endl;
}
cout<<endl;
}
}
void split(node *root,vector<vector<string> > &inputData,vector<pair<string,int> > classes){
//root->printInfo();
root->record_number=inputData.size();
VEC_STATI stati;
statistic(inputData,stati);
//printStati(stati);
//for(int i=0;i<rest_number;i++)
// cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";
//cout<<endl;
/*找到最大化GINI指標的劃分*/
double minGain=1.0; //最小的GINI增益
int fitIndex=-1;
string fitCond;
vector<pair<string,int> > fitleftclasses;
vector<pair<string,int> > fitrightclasses;
int fitleftnumber;
int fitrightnumber;
for(int i=0;i<stati.size();++i){ //掃描每一個自變量
MAP_ATTR_REST::const_iterator itr=stati[i].begin();
while(itr!=stati[i].end()){ //掃描自變量上的每一個取值
string condition=itr->first; //判定的條件,即到達左孩子的條件
//cout<<"cond 爲"<<X[i]+condition<<"時:";
vector<pair<string,int> > leftclasses(classes); //左孩子節點上類別、及對應的數目
vector<pair<string,int> > rightclasses(classes); //右孩子節點上類別、及對應的數目
int leftnumber=0; //左孩子節點上包含的類別數目
int rightnumber=0; //右孩子節點上包含的類別數目
for(int j=0;j<leftclasses.size();++j){ //更新類別對應的數目
string rest=leftclasses[j].first;
MAP_REST_COUNT::const_iterator iter2;
iter2=(itr->second).find(rest);
if(iter2==(itr->second).end()){ //沒找到
leftclasses[j].second=0;
rightnumber+=rightclasses[j].second;
}
else{ //找到
leftclasses[j].second=iter2->second;
leftnumber+=leftclasses[j].second;
rightclasses[j].second-=(iter2->second);
rightnumber+=rightclasses[j].second;
}
}
/**if(leftnumber==0 || rightnumber==0){
cout<<"左右有一邊爲空"<<endl;
for(int k=0;k<rest_number;k++)
cout<<leftclasses[k].first<<"\t"<<leftclasses[k].second<<"\t";
cout<<endl;
for(int k=0;k<rest_number;k++)
cout<<rightclasses[k].first<<"\t"<<rightclasses[k].second<<"\t";
cout<<endl;
}**/
double gain1=1.0; //計算GINI增益
double gain2=1.0;
if(leftnumber==0)
gain1=0.0;
else
for(int j=0;j<leftclasses.size();++j)
gain1-=pow(1.0*leftclasses[j].second/leftnumber,2.0);
if(rightnumber==0)
gain2=0.0;
else
for(int j=0;j<rightclasses.size();++j)
gain2-=pow(1.0*rightclasses[j].second/rightnumber,2.0);
double gain=1.0*leftnumber/(leftnumber+rightnumber)*gain1+1.0*rightnumber/(leftnumber+rightnumber)*gain2;
//cout<<"GINI增益:"<<gain<<endl;
if(gain<minGain){
//cout<<"GINI增益:"<<gain<<"\t"<<i<<"\t"<<condition<<endl;
fitIndex=i;
fitCond=condition;
fitleftclasses=leftclasses;
fitrightclasses=rightclasses;
fitleftnumber=leftnumber;
fitrightnumber=rightnumber;
minGain=gain;
}
itr++;
}
}
/*計算卡方值,看有沒有必要進行分裂*/
//cout<<"按"<<X[fitIndex]+fitCond<<"劃分,計算卡方"<<endl;
int **arr=new int*[2];
for(int i=0;i<2;i++)
arr[i]=new int[rest_number];
for(int i=0;i<rest_number;i++){
arr[0][i]=fitleftclasses[i].second;
arr[1][i]=fitrightclasses[i].second;
}
double chi=cal_chi(arr,2,rest_number);
//cout<<"chi="<<chi<<" CHI="<<CHI[rest_number-2]<<endl;
if(chi<CHI[rest_number-2]){ //獨立,沒必要再分裂了
delete []arr[0]; delete []arr[1]; delete []arr;
return; //不需要分裂函數就返回
}
delete []arr[0]; delete []arr[1]; delete []arr;
/*分裂*/
root->cond=X[fitIndex]+"="+fitCond; //root的分枝條件
//cout<<"分類條件:"<<root->cond<<endl;
node *travel=root; //root及其祖先節點的size都要加1
while(travel!=NULL){
(travel->size)++;
travel=travel->parent;
}
node *LChild=new node(root); //創建左右孩子
node *RChild=new node(root);
root->leftchild=LChild;
root->rightchild=RChild;
int maxLcount=0;
int maxRcount=0;
string Ldicision,Rdicision;
for(int i=0;i<rest_number;++i){ //統計哪種類別出現的最多,從而作出類別判定
if(fitleftclasses[i].second>maxLcount){
maxLcount=fitleftclasses[i].second;
Ldicision=fitleftclasses[i].first;
}
if(fitrightclasses[i].second>maxRcount){
maxRcount=fitrightclasses[i].second;
Rdicision=fitrightclasses[i].first;
}
}
LChild->decision=Ldicision;
RChild->decision=Rdicision;
LChild->precision=1.0*maxLcount/fitleftnumber;
RChild->precision=1.0*maxRcount/fitrightnumber;
/*遞歸對左右孩子進行分裂*/
vector<vector<string> > LinputData,RinputData;
splitInput(inputData,fitIndex,fitCond,LinputData,RinputData);
//cout<<"左邊inputData行數:"<<LinputData.size()<<endl;
//cout<<"右邊inputData行數:"<<RinputData.size()<<endl;
split(LChild,LinputData,fitleftclasses);
split(RChild,RinputData,fitrightclasses);
}
/*計算子樹的誤差代價*/
double calR2(node *root){
if(root->leftchild==NULL)
return (1-root->precision)*root->record_number/total_record_number;
else
return calR2(root->leftchild)+calR2(root->rightchild);
}
/*層次遍歷樹,給節點標上序號。同時計算alpha*/
void index(node *root,priority_queue<MyTriple> &pq){
int i=1;
queue<node*> que;
que.push(root);
while(!que.empty()){
node* n=que.front();
que.pop();
n->index=i++;
if(n->leftchild!=NULL){
que.push(n->leftchild);
que.push(n->rightchild);
//計算表面誤差率的增量
double r1=(1-n->precision)*n->record_number/total_record_number; //節點的誤差代價
double r2=calR2(n);
n->alpha=(r1-r2)/(n->size-1);
pq.push(MyTriple(n->alpha,n->size,n->index));
}
}
}
/*剪枝*/
void prune(node *root,priority_queue<MyTriple> &pq){
MyTriple triple=pq.top();
int i=triple.third;
queue<node*> que;
que.push(root);
while(!que.empty()){
node* n=que.front();
que.pop();
if(n->index==i){
cout<<"將要剪掉"<<i<<"的左右子樹"<<endl;
n->leftchild=NULL;
n->rightchild=NULL;
int s=n->size-1;
node *trav=n;
while(trav!=NULL){
trav->size-=s;
trav=trav->parent;
}
break;
}
else if(n->leftchild!=NULL){
que.push(n->leftchild);
que.push(n->rightchild);
}
}
}
void test(string filename,node *root){
ifstream ifs(filename.c_str());
if(!ifs){
cerr<<"open inputfile failed!"<<endl;
return;
}
string line;
getline(ifs,line);
string item;
istringstream strstm(line); //跳過第一行
map<string,string> independent; //自變量,即分類的依據
while(getline(ifs,line)){
istringstream strstm(line);
//strstm.str(line);
strstm>>item;
cout<<item<<"\t";
for(int i=0;i<ATTR_NUM;++i){
strstm>>item;
independent[X[i]]=item;
}
node *trav=root;
while(trav!=NULL){
if(trav->leftchild==NULL){
cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<<endl;;
break;
}
string cond=trav->cond;
string::size_type pos=cond.find("=");
string pre=cond.substr(0,pos);
string post=cond.substr(pos+1);
if(independent[pre]==post)
trav=trav->leftchild;
else
trav=trav->rightchild;
}
}
ifs.close();
}
int main(){
string inputFile="animal";
readInput(inputFile);
VEC_STATI stati; //最原始的統計
statistic(inputData,stati);
// for(int i=0;i<classes.size();++i)
// cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";
// cout<<endl;
node *root=new node();
split(root,inputData,classes); //分裂根節點
priority_queue<MyTriple> pq;
index(root,pq);
root->printTree();
cout<<"剪枝前使用該決策樹最多進行"<<root->size-1<<"次條件判斷"<<endl;
/**
//檢驗一個是不是表面誤差增量最小的被剪掉了
while(!pq.empty()){
MyTriple triple=pq.top();
pq.pop();
cout<<triple.first<<"\t"<<triple.second<<"\t"<<triple.third<<endl;
}
**/
test(inputFile,root);
prune(root,pq);
cout<<"剪枝後使用該決策樹最多進行"<<root->size-1<<"次條件判斷"<<endl;
test(inputFile,root);
return 0;
}
C4.5克服了ID3的2個缺點:
1.用信息增益選擇屬性時偏向於選擇分枝比較多的屬性值,即取值多的屬性
2.不能處理連貫屬性
詳細可參考這篇博客。