說明
參考文章-歸納決策樹ID3(Java實現),完成代碼編寫。
在原代碼的基礎上補充了預測函數,實現利用模型對新數據進行分類預測。
作者對ID3決策樹的介紹-ID3決策樹
決策樹採用xml文件保存,使用Dom4J類庫,點擊下載
讓Dom4J支持按XPath選擇節點,還得引入包jaxen.jar,點擊下載
源代碼彙總,點擊下載
思路
代碼
輸入文件採用ARFF格式,使用的訓練數據文件如下:
train.arff
@relation weather.symbolic
@attribute outlook {sunny,overcast,rainy}
@attribute temperature {hot,mild,cool}
@attribute humidity {high,normal}
@attribute windy {TRUE,FALSE}
@attribute play {yes,no}
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
ARFF(Attribute-Relation File Format):格式簡單明瞭,分爲兩部分,第一部分交代屬性及取值範圍,第二部分則是數據部分(data)。
由於只是測試代碼效果,測試集(predict.arff)也是上述數據,只是將類標相關的數據移除了。
ID3類
package ID3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.Character.Subset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
import org.w3c.dom.NodeList;
public class ID3 {
	// Train and predict metadata are kept separately so the two ARFF files
	// may list their columns in a different order.
	private ArrayList<String> trainAttribute = new ArrayList<String>(); // attribute names (training set)
	private ArrayList<ArrayList<String>> train_attributeValue = new ArrayList<ArrayList<String>>(); // legal values per training attribute
	private ArrayList<String> predictAttribute = new ArrayList<String>(); // attribute names (prediction set)
	private ArrayList<ArrayList<String>> predict_attributeValue = new ArrayList<ArrayList<String>>(); // legal values per prediction attribute
	private ArrayList<String[]> train_data = new ArrayList<String[]>(); // rows of the training @data section
	private ArrayList<String[]> predict_data = new ArrayList<String[]>(); // rows of the prediction @data section
	private String[] preLable; // predicted labels, filled by predict()
	int decatt; // column index of the class (decision) attribute
	// Reluctant (.*?) stops the value list at the first '}'.
	public static final String patternString = "@attribute(.*)[{](.*?)[}]";
	Document xmldoc; // XML document holding the learned decision tree
	Element root;

	public ID3() {
		// Create and initialise the XML document used to store the tree.
		xmldoc = DocumentHelper.createDocument();
		root = xmldoc.addElement("root");
		root.addElement("DecisionTree").addAttribute("value", "null");
	}

	/**
	 * Trains the model.
	 * @param class_name name of the class (label) attribute
	 * @param data_pathname path of the training ARFF file
	 * @return the decision tree as an XML document
	 */
	public Document train(String class_name, String data_pathname) {
		read_trainARFF(new File(data_pathname));
		setDec(class_name);
		// Candidate attributes: every column except the class column.
		// (LinkedList because attributes are removed as the tree grows.)
		LinkedList<Integer> ll = new LinkedList<Integer>();
		for (int i = 0; i < trainAttribute.size(); i++) {
			if (i != decatt) {
				ll.add(i);
			}
		}
		// Start with every training row at the tree root.
		ArrayList<Integer> al = new ArrayList<Integer>();
		for (int i = 0; i < train_data.size(); i++) {
			al.add(i);
		}
		buildDT("DecisionTree", "null", al, ll);
		return xmldoc;
	}

	/**
	 * Classifies the rows of an ARFF file with the decision tree kept in
	 * this instance (train must have been called first).
	 * @param data_pathname path of the prediction ARFF file
	 * @return one predicted label per data row
	 */
	public String[] predict(String data_pathname) {
		read_predictARFF(new File(data_pathname));
		preLable = new String[predict_data.size()];
		ArrayList<Integer> subset = new ArrayList<Integer>();
		for (int i = 0; i < predict_data.size(); i++) {
			subset.add(i);
		}
		Element decisionTree = xmldoc.getRootElement().element("DecisionTree");
		giveLable(decisionTree, subset);
		return preLable;
	}

	/**
	 * Recursively routes each row in {@code subset} down the tree and records
	 * the label of the leaf it reaches in {@code preLable}.
	 * @param node current tree node
	 * @param subset indices into predict_data that reached this node
	 */
	public void giveLable(Element node, ArrayList<Integer> subset) {
		List<Element> children = node.elements();
		if (children.size() == 0) { // leaf: its text is the class label
			String lable = node.getTextTrim();
			for (int index : subset) {
				preLable[index] = lable;
			}
		} else { // internal node: split the subset over the children
			for (Element child : children) {
				// Child element name = attribute name; look the column up in
				// the prediction header, which may be ordered differently.
				int col = predictAttribute.indexOf(child.getName());
				String value = child.attribute("value").getValue();
				ArrayList<Integer> matched = new ArrayList<Integer>();
				for (int row : subset) {
					if (predict_data.get(row)[col].equals(value)) {
						matched.add(row);
					}
				}
				giveLable(child, matched);
			}
		}
	}

	/**
	 * Parses an ARFF file: "@attribute name {v1,v2,...}" lines fill
	 * attrNames/attrValues, everything after "@data" becomes data rows.
	 * Shared by the train and predict readers (they were duplicated before).
	 */
	private void readARFF(File file, ArrayList<String> attrNames,
			ArrayList<ArrayList<String>> attrValues, ArrayList<String[]> data) {
		try {
			BufferedReader br = new BufferedReader(new FileReader(file));
			try {
				Pattern pattern = Pattern.compile(patternString);
				String line;
				while ((line = br.readLine()) != null) {
					Matcher matcher = pattern.matcher(line);
					if (matcher.find()) {
						attrNames.add(matcher.group(1).trim()); // group 1: attribute name
						// trim() every value: stray whitespace breaks equals() later.
						String[] values = matcher.group(2).split(",");
						ArrayList<String> al = new ArrayList<String>(values.length);
						for (String value : values) {
							al.add(value.trim());
						}
						attrValues.add(al);
					} else if (line.startsWith("@data")) {
						while ((line = br.readLine()) != null) {
							// FIX: the original used line == "" (reference
							// equality), which never skips anything, so blank
							// lines became bogus data rows.
							if (line.trim().isEmpty()) {
								continue;
							}
							data.add(line.split(","));
						}
					}
				}
			} finally {
				br.close(); // close even when readLine throws
			}
		} catch (IOException e1) {
			e1.printStackTrace();
		}
	}

	// Reads the training ARFF file into trainAttribute / train_attributeValue / train_data.
	public void read_trainARFF(File file) {
		readARFF(file, trainAttribute, train_attributeValue, train_data);
	}

	// Reads the prediction ARFF file into predictAttribute / predict_attributeValue / predict_data.
	public void read_predictARFF(File file) {
		readARFF(file, predictAttribute, predict_attributeValue, predict_data);
	}

	// Sets the class column by index; exits the program on an invalid index.
	public void setDec(int n) {
		if (n < 0 || n >= trainAttribute.size()) {
			System.err.println("決策變量指定錯誤。");
			System.exit(2);
		}
		decatt = n;
	}

	// Sets the class column by attribute name.
	public void setDec(String name) {
		setDec(trainAttribute.indexOf(name));
	}

	/**
	 * Entropy (in bits) of a class distribution given as counts.
	 * Adding Double.MIN_VALUE keeps Math.log away from exact log(0); the
	 * 0 * log(~0) terms still contribute 0 to the sum.
	 */
	public double getEntropy(int[] arr) {
		double entropy = 0.0;
		int sum = 0;
		for (int i = 0; i < arr.length; i++) {
			entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE) / Math.log(2);
			sum += arr[i];
		}
		entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
		entropy /= sum;
		return entropy;
	}

	/**
	 * Entropy of a class distribution when the total is already known.
	 * @param arr per-class counts
	 * @param sum total of arr (caller-supplied, must be the true sum)
	 */
	public double getEntropy(int[] arr, int sum) {
		double entropy = 0.0;
		for (int i = 0; i < arr.length; i++) {
			entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE) / Math.log(2);
		}
		entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
		entropy /= sum;
		return entropy;
	}

	// True when every row in the subset carries the same class label
	// (such a node becomes a leaf). Precondition: subset is non-empty.
	public boolean infoPure(ArrayList<Integer> subset) {
		String value = train_data.get(subset.get(0))[decatt];
		for (int i = 1; i < subset.size(); i++) {
			String next = train_data.get(subset.get(i))[decatt];
			if (!value.trim().equals(next.trim())) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Weighted entropy of splitting {@code subset} (row indices into
	 * train_data) on attribute column {@code index}.
	 */
	public double calNodeEntropy(ArrayList<Integer> subset, int index) {
		int sum = subset.size();
		double entropy = 0.0;
		ArrayList<String> nodeValues = train_attributeValue.get(index);
		ArrayList<String> classValues = train_attributeValue.get(decatt);
		// info[v][c] = rows with attribute value v and class c; count[v] = row total per value.
		int[][] info = new int[nodeValues.size()][classValues.size()];
		int[] count = new int[nodeValues.size()];
		for (int i = 0; i < sum; i++) {
			int n = subset.get(i);
			int nodeind = nodeValues.indexOf(train_data.get(n)[index]);
			count[nodeind]++;
			int decind = classValues.indexOf(train_data.get(n)[decatt].trim());
			info[nodeind][decind]++;
		}
		for (int i = 0; i < info.length; i++) {
			// FIX: an empty group contributes nothing. The original called
			// getEntropy on all-zero counts, whose 0/0 yields NaN and
			// poisoned the whole sum, silently disqualifying the attribute.
			if (count[i] > 0) {
				entropy += getEntropy(info[i]) * count[i] / sum;
			}
		}
		return entropy;
	}

	/**
	 * Builds the subtree rooted at the XML element named {@code node} whose
	 * "value" attribute equals {@code value} (core ID3 routine).
	 * @param node   element name of the node to grow
	 * @param value  value attribute identifying the node
	 * @param subset training-row indices that reached this node
	 * @param selatt candidate attribute columns still available
	 */
	public void buildDT(String node, String value, ArrayList<Integer> subset,
			LinkedList<Integer> selatt) {
		Element ele = null;
		@SuppressWarnings("unchecked")
		List<Element> list = root.selectNodes("//" + node);
		for (Element e : list) {
			if (e.attributeValue("value").equals(value)) {
				ele = e;
				break;
			}
		}
		if (ele != null) { // FIX: the original fell through with the last (wrong) element
			buildDT(ele, subset, selatt);
		}
	}

	// Recursive worker: grows the tree directly under the resolved element, so
	// same-named nodes in different branches can never be confused (the public
	// overload's XPath lookup only ever matched the FIRST name/value hit).
	private void buildDT(Element ele, ArrayList<Integer> subset, LinkedList<Integer> selatt) {
		if (subset.isEmpty()) {
			return; // no training rows reached this node; nothing to learn
		}
		if (infoPure(subset)) {
			// All rows share one label: make this a leaf.
			ele.setText(train_data.get(subset.get(0))[decatt]);
			return;
		}
		if (selatt.isEmpty()) {
			// FIX: attributes exhausted but labels impure — the original ran
			// on with minIndex == -1 and crashed; fall back to majority vote.
			ele.setText(majorityLabel(subset));
			return;
		}
		// Pick the attribute with minimum weighted entropy (maximum gain).
		// FIX: the original skipped selatt POSITIONS equal to decatt (a
		// position, not a column index), wrongly ignoring a candidate when
		// decatt < selatt.size(); selatt never contains decatt, so no check.
		int minIndex = -1;
		double minEntropy = Double.MAX_VALUE;
		for (int attr : selatt) {
			double entropy = calNodeEntropy(subset, attr);
			if (entropy < minEntropy) {
				minIndex = attr;
				minEntropy = entropy;
			}
		}
		String nodeName = trainAttribute.get(minIndex);
		// FIX: copy before removing. The original mutated the shared list, so
		// an attribute consumed deep in one branch vanished from its siblings.
		LinkedList<Integer> remaining = new LinkedList<Integer>(selatt);
		remaining.remove(Integer.valueOf(minIndex));
		for (String val : train_attributeValue.get(minIndex)) {
			Element child = ele.addElement(nodeName).addAttribute("value", val);
			ArrayList<Integer> branch = new ArrayList<Integer>();
			for (int row : subset) {
				if (train_data.get(row)[minIndex].equals(val)) {
					branch.add(row);
				}
			}
			if (branch.isEmpty()) {
				// FIX: no training rows take this value — the original would
				// crash in infoPure; label with the parent's majority class.
				child.setText(majorityLabel(subset));
			} else {
				buildDT(child, branch, remaining);
			}
		}
	}

	// Most frequent class label among the given training rows (tie → first declared).
	private String majorityLabel(ArrayList<Integer> subset) {
		ArrayList<String> classes = train_attributeValue.get(decatt);
		int[] counts = new int[classes.size()];
		for (int row : subset) {
			int ci = classes.indexOf(train_data.get(row)[decatt].trim());
			if (ci >= 0) {
				counts[ci]++;
			}
		}
		int best = 0;
		for (int i = 1; i < counts.length; i++) {
			if (counts[i] > counts[best]) {
				best = i;
			}
		}
		return classes.get(best);
	}

	/**
	 * Writes the decision tree to an XML file, pretty-printed.
	 * @param filename output path
	 */
	public void writeXML(String filename) {
		try {
			File file = new File(filename);
			if (!file.exists()) {
				file.createNewFile();
			}
			FileWriter fw = new FileWriter(file);
			OutputFormat format = OutputFormat.createPrettyPrint(); // indented output
			XMLWriter output = new XMLWriter(fw, format);
			output.write(xmldoc);
			output.close();
		} catch (IOException e) {
			System.out.println(e.getMessage());
		}
	}
}
主函數
package ID3;
public class Main {
	public static void main(String[] args) {
		// Train a model, persist the tree, then classify the test file.
		ID3 model = new ID3();
		model.train("play", "files/ID3/train.arff");
		model.writeXML("files/ID3/ID3_Tree.xml");
		String[] labels = model.predict("files/ID3/predict.arff");
		int i = 0;
		for (String label : labels) {
			System.out.println(i + label); // row index immediately followed by its label
			i++;
		}
	}
}
決策樹xml文件
<?xml version="1.0" encoding="UTF-8"?>
<root>
<DecisionTree value="null">
<outlook value="sunny">
<humidity value="high">no</humidity>
<humidity value="normal">yes</humidity>
</outlook>
<outlook value="overcast">yes</outlook>
<outlook value="rainy">
<windy value="TRUE">no</windy>
<windy value="FALSE">yes</windy>
</outlook>
</DecisionTree>
</root>