1.僞代碼
- 將除類別屬性以外的所有屬性放入可選屬性集A中;
- 若訓練集D的所有樣本屬於同一類別k,則構建葉節點,節點的類標籤分佈爲[第k個爲1.0,其他爲0.0],返回該葉節點;
- 若A爲空,使用ClassDistribution獲得類標籤分佈,以此構建葉節點,並返回葉節點;
- 否則,計算A中所有屬性的信息增益,找出信息增益最大的屬性Ag;
- 若Ag的信息增益小於閾值,則使用ClassDistribution獲得類標籤分佈,以此構建葉節點,並返回葉節點;
- 否則,以Ag爲劃分屬性、基於當前訓練集構建內部節點,然後按Ag的N個取值將當前訓練集劃分成N個子集;對每個非空子集,以該子集爲訓練集D、A-{Ag}爲可選屬性集,遞歸執行上述第2至第6步得到子樹,返回該節點及其子樹。
2.代碼
package weka.classifiers.xwq;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
/**
 * A basic ID3 decision-tree classifier for nominal attributes and a nominal class.
 *
 * <p>At each node the candidate attribute with the highest information gain (in bits)
 * is chosen. A node becomes a leaf when all of its instances share one class, when no
 * candidate attributes remain, or when the best gain is below {@link #getThreshold()}.
 * Instances with a missing class value are dropped before training; missing values of
 * other attributes are rejected during training.
 */
public class ID3 extends Classifier
{
    /** Serialization version. */
    private static final long serialVersionUID = 1L;

    /** ln(2), used to convert natural logarithms to base 2. */
    private static final double LN2 = Math.log(2);

    /** Root of the tree built by buildClassifier; null before training. */
    private TreeNodes m_Tree = null;

    /** Number of class labels in the training data. */
    private int m_NumClassValues = -1;

    /** Minimum information gain (in bits) a split must reach. */
    private double m_Threshold = 0.1;

    /**
     * Returns the minimum information gain (in bits) a split must reach.
     *
     * @return the split threshold (0.1 unless changed)
     */
    public double getThreshold()
    {
        return m_Threshold;
    }

    /**
     * Sets the minimum information gain (in bits) a split must reach. Nodes whose best
     * gain is below this value become leaves. Takes effect on the next build.
     *
     * @param threshold the new split threshold
     */
    public void setThreshold(double threshold)
    {
        m_Threshold = threshold;
    }

    /**
     * Builds the decision tree from a copy of {@code data}; the argument is not modified.
     *
     * @param data training instances with the class index set
     * @throws IllegalArgumentException if the class or any other attribute is not
     *         nominal, or if a non-class attribute has missing values
     * @throws Exception if Weka cannot process the data
     */
    @Override
    public void buildClassifier(Instances data) throws Exception
    {
        Instances train = new Instances(data);
        // Instances without a class label carry no information for ID3.
        train.deleteWithMissingClass();
        checkData(train);
        m_NumClassValues = train.numClasses();

        // The candidate set is rebuilt on every call, so repeated training starts clean.
        List<Attribute> candidates = new ArrayList<>();
        for (int i = 0; i < train.numAttributes(); i++)
        {
            if (i != train.classIndex())
            {
                candidates.add(train.attribute(i));
            }
        }
        m_Tree = BuildTree(train, candidates);
    }

    /**
     * Verifies that the data fits ID3: a nominal class, and nominal non-class
     * attributes without missing values.
     *
     * @param data training instances with the class index set
     * @throws IllegalArgumentException if any requirement is violated
     */
    private static void checkData(Instances data)
    {
        if (!data.classAttribute().isNominal())
        {
            throw new IllegalArgumentException("ID3: the class attribute must be nominal");
        }
        int classIndex = data.classIndex();
        for (int a = 0; a < data.numAttributes(); a++)
        {
            if (a == classIndex)
            {
                continue;
            }
            Attribute attribute = data.attribute(a);
            // A non-nominal attribute has no values to branch on; its computed gain
            // would equal the full entropy, so it would always be chosen.
            if (!attribute.isNominal())
            {
                throw new IllegalArgumentException("ID3: attribute '" + attribute.name()
                        + "' is not nominal; only nominal attributes are supported");
            }
            for (int i = 0; i < data.numInstances(); i++)
            {
                if (data.instance(i).isMissing(a))
                {
                    throw new IllegalArgumentException("ID3: attribute '" + attribute.name()
                            + "' has missing values, which are not supported");
                }
            }
        }
    }

    /**
     * Recursively builds the (sub)tree for {@code data}.
     *
     * @param data instances reaching this node (may be empty only at the root)
     * @param candidates attributes still available for splitting; never modified
     * @return the root of the subtree
     */
    private TreeNodes BuildTree(Instances data, List<Attribute> candidates)
    {
        double[] distribution = ClassDistribution(data);

        // Pure (or empty) node, or nothing left to split on: make a leaf.
        if (SameClassValue(data) || candidates.isEmpty())
        {
            return new TreeNodes(distribution);
        }

        double[] gain = computeInformationGain(data, candidates);
        int best = Utils.maxIndex(gain);
        if (gain[best] < m_Threshold)
        {
            return new TreeNodes(distribution);
        }

        Attribute selected = candidates.get(best);
        // Children get A - {Ag} as a fresh list. Recursive calls copy before removing,
        // so siblings never see attributes consumed in each other's subtrees.
        List<Attribute> remaining = new ArrayList<>(candidates);
        remaining.remove(best);

        TreeNodes node = new TreeNodes(distribution, selected);
        Instances[] subsets = SplitDataByAttribute(data, selected);
        for (int v = 0; v < subsets.length; v++)
        {
            // Empty subsets get no child; prediction falls back to this node's distribution.
            if (subsets[v].numInstances() != 0)
            {
                node.m_SubTree[v] = BuildTree(subsets[v], remaining);
            }
        }
        return node;
    }

    /**
     * Partitions {@code data} by the value of a nominal attribute.
     *
     * @param data instances to split (no missing values for {@code attribute})
     * @param attribute nominal attribute to split on
     * @return one subset per attribute value, indexed by value index
     */
    private Instances[] SplitDataByAttribute(Instances data, Attribute attribute)
    {
        Instances[] subsets = new Instances[attribute.numValues()];
        for (int v = 0; v < subsets.length; v++)
        {
            subsets[v] = new Instances(data, 0);
        }
        for (int i = 0; i < data.numInstances(); i++)
        {
            Instance instance = data.instance(i);
            subsets[(int) instance.value(attribute)].add(instance);
        }
        return subsets;
    }

    /**
     * Computes the information gain (in bits) of each candidate attribute on {@code data}.
     *
     * @param data non-empty instances at the current node
     * @param candidates attributes to evaluate
     * @return gains aligned with {@code candidates}
     */
    private double[] computeInformationGain(Instances data, List<Attribute> candidates)
    {
        double[] gain = new double[candidates.size()];
        double before = calculateEntropy(data);
        double total = data.numInstances();
        for (int i = 0; i < gain.length; i++)
        {
            double remainder = 0.0;
            for (Instances subset : SplitDataByAttribute(data, candidates.get(i)))
            {
                if (subset.numInstances() != 0)
                {
                    remainder += subset.numInstances() / total * calculateEntropy(subset);
                }
            }
            gain[i] = before - remainder;
        }
        return gain;
    }

    /**
     * Computes the class entropy of {@code data} in bits.
     *
     * @param data instances to measure
     * @return the entropy, or 0 for an empty set
     */
    private double calculateEntropy(Instances data)
    {
        int total = data.numInstances();
        if (total == 0)
        {
            return 0.0;
        }
        double entropy = 0.0;
        for (double count : classCounts(data))
        {
            // 0 * log(0) is taken as 0.
            if (count > 0)
            {
                double p = count / total;
                entropy -= p * log2(p);
            }
        }
        return entropy;
    }

    /** Returns the base-2 logarithm of {@code x}. */
    private static double log2(double x)
    {
        return Math.log(x) / LN2;
    }

    /**
     * Returns the relative class frequencies in {@code data}.
     *
     * @param data instances to summarize
     * @return class probabilities; uniform when {@code data} is empty
     */
    private double[] ClassDistribution(Instances data)
    {
        double[] distribution = classCounts(data);
        int total = data.numInstances();
        for (int i = 0; i < distribution.length; i++)
        {
            // An empty set carries no evidence, so fall back to a uniform distribution.
            distribution[i] = total == 0 ? 1.0 / distribution.length : distribution[i] / total;
        }
        return distribution;
    }

    /** Counts how many instances of {@code data} carry each class label. */
    private double[] classCounts(Instances data)
    {
        double[] counts = new double[m_NumClassValues];
        for (int i = 0; i < data.numInstances(); i++)
        {
            counts[(int) data.instance(i).classValue()]++;
        }
        return counts;
    }

    /**
     * Checks whether every instance in {@code data} has the same class label.
     *
     * @param data instances to inspect
     * @return true if all labels are equal or {@code data} is empty
     */
    private boolean SameClassValue(Instances data)
    {
        if (data.numInstances() == 0)
        {
            return true;
        }
        int first = (int) data.instance(0).classValue();
        for (int i = 1; i < data.numInstances(); i++)
        {
            if ((int) data.instance(i).classValue() != first)
            {
                return false;
            }
        }
        return true;
    }

    /**
     * Walks the tree for {@code instance} and returns the class distribution of the node
     * where the walk stops. The walk stops at a leaf, at a node whose test attribute is
     * missing in {@code instance}, or at a node with no child for the instance's value.
     *
     * @param instance instance to classify
     * @return a fresh copy of the class distribution
     * @throws IllegalStateException if the classifier has not been built
     */
    @Override
    public double[] distributionForInstance(Instance instance)
    {
        if (m_Tree == null)
        {
            throw new IllegalStateException("ID3: buildClassifier() must be called first");
        }
        TreeNodes current = m_Tree;
        while (current.m_SubTree != null)
        {
            Attribute attribute = current.m_Attribute;
            if (instance.isMissing(attribute))
            {
                break;
            }
            TreeNodes child = current.m_SubTree[(int) instance.value(attribute)];
            if (child == null)
            {
                break;
            }
            current = child;
        }
        // Copy so callers cannot alter the model's stored distribution.
        return current.m_ClassDistribution.clone();
    }

    /**
     * Command-line entry point: runs Weka's standard evaluation with {@code args}.
     *
     * @param args Weka evaluation options
     */
    public static void main(String[] args)
    {
        try
        {
            System.out.println(Evaluation.evaluateModel(new ID3(), args));
        }
        catch (Exception e)
        {
            // Only the message is printed; add e.printStackTrace() when debugging.
            System.err.println(e.getMessage());
        }
    }

    /**
     * A tree node: a leaf when {@code m_Attribute} is null, otherwise an internal node
     * with one child slot per value of {@code m_Attribute}. Static so it holds no hidden
     * reference to the enclosing classifier. It must be Serializable because the
     * classifier (and hence its tree) gets serialized; otherwise
     * java.io.NotSerializableException is thrown.
     */
    private static class TreeNodes implements java.io.Serializable
    {
        private static final long serialVersionUID = 1L;

        /** Class distribution of the training instances that reached this node. */
        final double[] m_ClassDistribution;

        /** Attribute tested at this node; null for leaves. */
        final Attribute m_Attribute;

        /** Children indexed by attribute value (null entries for unseen values); null for leaves. */
        final TreeNodes[] m_SubTree;

        /** Creates a leaf with the given class distribution. */
        TreeNodes(double[] distribution)
        {
            this(distribution, null);
        }

        /** Creates an internal node testing {@code attribute}, or a leaf if it is null. */
        TreeNodes(double[] distribution, Attribute attribute)
        {
            this.m_ClassDistribution = distribution;
            this.m_Attribute = attribute;
            this.m_SubTree = attribute == null ? null : new TreeNodes[attribute.numValues()];
        }
    }
}