基於WEKA的決策樹(ID3)代碼實現

1.僞代碼

  1. 將所有屬性放入可選屬性集A中;
  2. 若訓練集D的所有樣本屬於同一類別k,則構建葉節點,節點的類標籤分佈爲[第k個爲1.0,其他爲0.0],返回該葉節點;
  3. 若A爲空,使用ClassDistribution獲得類標籤分佈,以此構建葉節點,並返回葉節點;
  4. 否則,計算A中所有屬性的信息增益,找出信息增益最大的屬性Ag;
  5. 若Ag的信息增益小於閾值,則使用ClassDistribution獲得類標籤分佈,以此構建葉節點,並返回葉節點;
  6. 否則,基於當前訓練集構建節點,然後基於Ag的取值個數N將當前訓練集進行劃分成N個子集,對第i個非空子集,以A-Ag爲可選屬性集,遞歸2-6得到所有子樹,返回該節點及其子樹。

2.代碼

package weka.classifiers.xwq;

import java.util.ArrayList;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

public class ID3 extends Classifier
{
	private static final long serialVersionUID = 1L;

	/**
	 * Copy of the training data the tree was built from.
	 */
	private Instances m_Train = null;
	/**
	 * Root of the decision tree built from the training data.
	 */
	private TreeNodes m_Tree = null;

	/**
	 * Number of attributes in the training data (including the class attribute).
	 */
	private int m_NumAttributes = -1;
	/**
	 * Candidate attributes still available for splitting on the current path.
	 */
	private ArrayList<Attribute> m_AttributesOptions = new ArrayList<>();
	/**
	 * Number of distinct class values.
	 */
	private int m_NumClassValues = -1;
	/**
	 * Minimum information gain required to keep splitting; below it a leaf is made.
	 */
	private double m_Threshold = 0.1;

	/**
	 * Builds the ID3 decision tree from the given training data.
	 *
	 * @param data training instances; attributes and class are assumed nominal
	 * @throws Exception if the tree cannot be built
	 */
	@Override
	public void buildClassifier(Instances data) throws Exception
	{
		m_Train = new Instances(data);

		m_NumAttributes = m_Train.numAttributes();
		// BUGFIX: clear state left over from any previous call so the
		// classifier can be rebuilt on new data without corruption.
		m_AttributesOptions.clear();
		for (int i = 0; i < m_NumAttributes; i++)
			if (i != m_Train.classIndex())
				m_AttributesOptions.add(m_Train.attribute(i));

		m_NumClassValues = data.numClasses();
		m_Tree = BuildTree(m_Train);
	}

	/**
	 * Recursively builds a (sub)tree for the given data using the attributes
	 * currently in {@link #m_AttributesOptions}.
	 *
	 * @param data the (sub)set of training instances for this node
	 * @return the root of the constructed (sub)tree
	 */
	private TreeNodes BuildTree(Instances data)
	{
		// All instances share one class value: build a pure leaf.
		if (SameClassValue(data))
		{
			double []classDistribution = new double[m_NumClassValues];
			// Guard: SameClassValue reports true for an empty set as well,
			// in which case there is no instance(0) to read.
			if (data.numInstances() > 0)
				classDistribution[(int)data.instance(0).classValue()] = 1.0;
			return new TreeNodes(data, classDistribution);
		}
		// No candidate attributes left: leaf with the empirical distribution.
		if (m_AttributesOptions.isEmpty())
			return new TreeNodes(data, ClassDistribution(data));

		double gain[] = computeInformationGain(data);
		int selectedAttIndex = Utils.maxIndex(gain);
		// Best achievable gain is below the threshold: stop splitting.
		if (gain[selectedAttIndex] < m_Threshold)
			return new TreeNodes(data, ClassDistribution(data));

		Attribute selectedAttribute = m_AttributesOptions.get(selectedAttIndex);
		// Exclude the chosen attribute while building the children ...
		m_AttributesOptions.remove(selectedAttIndex);
		TreeNodes root = new TreeNodes(data, ClassDistribution(data), selectedAttribute);
		Instances subData[] = SplitDataByAttribute(data, selectedAttribute);
		for (int i = 0; i < subData.length; i++)
			if (subData[i].numInstances() != 0)
				root.m_SubTree[i] = BuildTree(subData[i]);
		// BUGFIX: ... and restore it afterwards. ID3 removes an attribute
		// only along the current path; without this backtracking step,
		// sibling branches elsewhere in the tree could never select it.
		m_AttributesOptions.add(selectedAttIndex, selectedAttribute);

		return root;
	}

	/**
	 * Partitions the data into one subset per value of the given nominal
	 * attribute.
	 *
	 * @param data      instances to split
	 * @param attribute nominal attribute to split on
	 * @return one (possibly empty) Instances object per attribute value
	 */
	private Instances[] SplitDataByAttribute(Instances data, Attribute attribute)
	{
		Instances subInstances[] = new Instances[attribute.numValues()];
		for (int i = 0; i < subInstances.length; i++)
			subInstances[i] = new Instances(data, 0);

		for (int i = 0; i < data.numInstances(); i++)
		{
			Instance instance = data.instance(i);
			int attValue = (int)instance.value(attribute);
			subInstances[attValue].add(instance);
		}

		return subInstances;
	}

	/**
	 * Computes the information gain of every candidate attribute on the data.
	 *
	 * @param data instances to evaluate on
	 * @return gains, parallel to {@link #m_AttributesOptions}
	 */
	private double[] computeInformationGain(Instances data)
	{
		double []gain = new double[m_AttributesOptions.size()];
		double entropy_before = calculateEntropy(data);

		for (int i = 0; i < gain.length; i++)
		{
			Attribute attribute = m_AttributesOptions.get(i);
			Instances []subInstances = SplitDataByAttribute(data, attribute);
			double []entropies = new double[attribute.numValues()];
			// Empty subsets keep entropy 0.0 and contribute nothing below.
			for (int j = 0; j < entropies.length; j++)
				if (subInstances[j].numInstances() != 0)
					entropies[j] = calculateEntropy(subInstances[j]);

			// gain = H(D) - sum_j |D_j|/|D| * H(D_j)
			gain[i] = entropy_before;
			for (int j = 0; j < entropies.length; j++)
				gain[i] -= ((double)subInstances[j].numInstances() / (double)data.numInstances())*entropies[j];
		}

		return gain;
	}

	/**
	 * Calculates the class entropy (base 2) of the data.
	 *
	 * @param data instances to measure; assumed non-empty by callers
	 * @return the Shannon entropy of the class distribution
	 */
	private double calculateEntropy(Instances data)
	{
		double []count = new double[m_NumClassValues];
		for (int i = 0; i < data.numInstances(); i++)
		{
			int classValue = (int)data.instance(i).classValue();
			count[classValue] ++;
		}

		double entropy = 0.0;
		for (int i = 0; i < count.length; i++)
			entropy += count[i]/(double)data.numInstances() * log2(count[i], (double)data.numInstances());

		return -entropy;
	}

	/**
	 * Returns log2(x/y), treating a (near-)zero numerator or denominator as 0
	 * so that 0 * log(0) terms vanish in the entropy sum.
	 */
	private double log2(double x, double y)
	{
		if (x < 1e-6 || y < 1e-6)
			return 0.0;
		return Math.log(x / y) / Math.log(2);
	}

	/**
	 * Returns the empirical class distribution of the data.
	 *
	 * @param data instances to count; assumed non-empty by callers
	 * @return relative frequency of each class value
	 */
	private double[] ClassDistribution(Instances data)
	{
		double count[] = new double[m_NumClassValues];
		for (int i = 0; i < data.numInstances(); i++)
		{
			Instance instance = data.instance(i);
			int classValue = (int)instance.classValue();
			count[classValue] ++;
		}

		for (int i = 0; i < count.length; i++)
			count[i] = count[i] / data.numInstances();

		return count;
	}

	/**
	 * Checks whether all instances in the data share the same class value.
	 * An empty dataset is vacuously considered pure.
	 *
	 * @param data instances to check
	 * @return true if every instance has the same class value
	 */
	private boolean SameClassValue(Instances data)
	{
		if (data.numInstances() == 0)
			return true;

		int sameClassValue = (int)data.instance(0).classValue();
		for (int i = 1; i < data.numInstances(); i++)
			if ((int)data.instance(i).classValue() != sameClassValue)
				return false;

		return true;
	}

	/**
	 * Routes the instance down the tree and returns the class distribution of
	 * the deepest node it reaches.
	 *
	 * @param instance the instance to classify
	 * @return the class distribution at the reached (leaf) node
	 */
	@Override
	public double[] distributionForInstance(Instance instance)
	{
		TreeNodes current = m_Tree;
		while (current.m_SubTree != null)
		{
			int attValue = (int)instance.value(current.m_Attribute);
			TreeNodes child = current.m_SubTree[attValue];
			// BUGFIX: an attribute value whose training subset was empty has
			// no child; fall back to this node's distribution instead of
			// throwing a NullPointerException.
			if (child == null)
				break;
			current = child;
		}

		return current.m_ClassDistribution;
	}

	public static void main(String[] args)
	{
		try {
			System.out.println(Evaluation.evaluateModel(new ID3(), args));
		} catch (Exception e) {
			System.err.println(e.getMessage());
		}
	}

	/**
	 * A decision-tree node. Declared static so it carries no hidden reference
	 * to the enclosing ID3 instance (avoids serialization problems and leaks);
	 * it must implement Serializable or saving the model throws
	 * java.io.NotSerializableException.
	 */
	private static class TreeNodes implements java.io.Serializable
	{
		private static final long serialVersionUID = 1L;

		/**
		 * Class distribution at this node.
		 */
		public double m_ClassDistribution[] = null;
		/**
		 * Training subset this node was built from.
		 */
		public Instances m_SubTrain = null;
		/**
		 * Attribute this (internal) node splits on; null for leaves.
		 */
		public Attribute m_Attribute = null;
		/**
		 * One subtree per value of m_Attribute; null for leaves, and a slot
		 * stays null when its value's training subset was empty.
		 */
		public TreeNodes m_SubTree[] = null;

		/**
		 * Creates a leaf node.
		 */
		public TreeNodes(Instances data, double[] value)
		{
			this.m_SubTrain = new Instances(data);
			this.m_ClassDistribution = value;
		}

		/**
		 * Creates an internal node split on the given attribute.
		 */
		public TreeNodes(Instances data, double[] value, Attribute attribute)
		{
			this.m_SubTrain = new Instances(data);
			this.m_ClassDistribution = value;
			this.m_Attribute = attribute;
			this.m_SubTree = new TreeNodes[this.m_Attribute.numValues()];
		}
	}
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章