ID3決策樹(Java實現)

說明

參考文章-歸納決策樹ID3(Java實現),完成代碼編寫。
在原代碼的基礎上補充了預測函數,實現利用模型對新數據進行分類預測。
作者對ID3決策樹的介紹-ID3決策樹

決策樹採用xml文件保存,使用Dom4J類庫,點擊下載
讓Dom4J支持按XPath選擇節點,還得引入包jaxen.jar,點擊下載
源代碼彙總,點擊下載

思路

這裏寫圖片描述

代碼

輸入文件採用ARFF格式,使用的訓練數據文件如下:
train.arff

@relation weather.symbolic 
@attribute outlook {sunny,overcast,rainy} 
@attribute temperature {hot,mild,cool} 
@attribute humidity {high,normal} 
@attribute windy {TRUE,FALSE} 
@attribute play {yes,no} 

@data 
sunny,hot,high,FALSE,no 
sunny,hot,high,TRUE,no 
overcast,hot,high,FALSE,yes 
rainy,mild,high,FALSE,yes 
rainy,cool,normal,FALSE,yes 
rainy,cool,normal,TRUE,no 
overcast,cool,normal,TRUE,yes 
sunny,mild,high,FALSE,no 
sunny,cool,normal,FALSE,yes 
rainy,mild,normal,FALSE,yes 
sunny,mild,normal,TRUE,yes 
overcast,mild,high,TRUE,yes 
overcast,hot,normal,FALSE,yes 
rainy,mild,high,TRUE,no

ARFF(Attribute-Relation File Format):格式簡單明瞭,分爲兩部分,第一部分交代屬性及取值範圍,第二部分則是數據部分(data)。
由於只是測試代碼效果,測試集(predict.arff)也是上述數據,只是將類標相關的數據移除了。

ID3類

package ID3;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.Character.Subset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
import org.w3c.dom.NodeList;

public class ID3 {
    // Training and test metadata are stored separately so the two ARFF
    // files may list their columns in different orders.
    private ArrayList<String> trainAttribute = new ArrayList<String>(); // training-set attribute names
    private ArrayList<ArrayList<String>> train_attributeValue = new ArrayList<ArrayList<String>>(); // value domain of each training attribute
    private ArrayList<String> predictAttribute = new ArrayList<String>(); // test-set attribute names
    private ArrayList<ArrayList<String>> predict_attributeValue = new ArrayList<ArrayList<String>>(); // value domain of each test attribute

    private ArrayList<String[]> train_data = new ArrayList<String[]>(); // training rows (the @data section)
    private ArrayList<String[]> predict_data = new ArrayList<String[]>(); // test rows

    private String[] preLable; // predicted label per test row, filled by predict()

    int decatt; // column index of the class (decision) attribute
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
    // Reluctant "*?" stops at the first '}' so a later brace is not swallowed.

    Document xmldoc;
    Element root;

    public ID3() {
        // Create and initialise the XML document that will store the tree.
        xmldoc = DocumentHelper.createDocument();
        root = xmldoc.addElement("root");
        root.addElement("DecisionTree").addAttribute("value", "null");
    }

    /**
     * Trains the model on an ARFF file.
     *
     * @param class_name    name of the class (label) attribute
     * @param data_pathname path of the training ARFF file
     * @return the XML document holding the decision tree
     */
    public Document train(String class_name, String data_pathname) {
        read_trainARFF(new File(data_pathname));
        setDec(class_name);
        // Candidate attribute indices; LinkedList removes elements cheaply.
        LinkedList<Integer> ll = new LinkedList<Integer>();
        for (int i = 0; i < trainAttribute.size(); i++) {
            if (i != decatt) ll.add(i); // the class column need not be the last one
        }
        // Row indices of the full training set.
        ArrayList<Integer> al = new ArrayList<Integer>();
        for (int i = 0; i < train_data.size(); i++) {
            al.add(i);
        }
        buildDT("DecisionTree", "null", al, ll);
        return xmldoc;
    }

    /**
     * Classifies every row of a test ARFF file using the tree kept in this
     * instance; train() must have been called first.
     *
     * @param data_pathname path of the test ARFF file
     * @return predicted label per test row (array index == row index)
     */
    public String[] predict(String data_pathname) {
        read_predictARFF(new File(data_pathname));
        preLable = new String[predict_data.size()];

        ArrayList<Integer> subset = new ArrayList<Integer>();
        for (int i = 0; i < predict_data.size(); i++) {
            subset.add(i);
        }

        Element DecisionTree = xmldoc.getRootElement().element("DecisionTree");
        giveLable(DecisionTree, subset);
        return preLable;
    }

    /**
     * Recursively routes test rows down the tree; rows that reach a leaf
     * receive that leaf's class label.
     *
     * @param node   current tree node
     * @param subset indices into predict_data that reached this node
     */
    public void giveLable(Element node, ArrayList<Integer> subset) {
        List<Element> list = node.elements();
        if (list.size() == 0) { // leaf node: its text is the class label
            String lable = node.getTextTrim();
            for (int index : subset) {
                preLable[index] = lable;
            }
        } else { // internal node: partition the rows among the children
            for (Element e : list) {
                String name = e.getName();
                String value = e.attribute("value").getValue();
                int index = predictAttribute.indexOf(name);
                if (index < 0) {
                    continue; // attribute missing from the test header; cannot follow this branch
                }
                ArrayList<Integer> temp = new ArrayList<Integer>();
                for (int i = 0; i < subset.size(); i++) { // keep only rows matching this branch value
                    if (predict_data.get(subset.get(i))[index].equals(value)) {
                        temp.add(subset.get(i));
                    }
                }
                giveLable(e, temp);
            }
        }
    }

    // Reads a training ARFF file, filling trainAttribute / train_attributeValue / train_data.
    public void read_trainARFF(File file) {
        readARFF(file, trainAttribute, train_attributeValue, train_data);
    }

    // Reads a test ARFF file, filling predictAttribute / predict_attributeValue / predict_data.
    public void read_predictARFF(File file) {
        readARFF(file, predictAttribute, predict_attributeValue, predict_data);
    }

    /**
     * Shared ARFF parser: "@attribute name {v1,v2,...}" lines feed the
     * attribute lists, everything after "@data" feeds the data rows.
     * All extracted values are trimmed once here so downstream string
     * comparisons never fail on stray whitespace.
     */
    private void readARFF(File file, ArrayList<String> attributes,
            ArrayList<ArrayList<String>> attributeValues, ArrayList<String[]> data) {
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(file));
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    attributes.add(matcher.group(1).trim()); // group 1: attribute name
                    String[] values = matcher.group(2).split(","); // group 2: value list
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    attributeValues.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        // BUGFIX: was `line == ""` (identity comparison, never true),
                        // so blank lines produced bogus one-field rows.
                        if (line.trim().isEmpty())
                            continue;
                        String[] row = line.split(",");
                        for (int i = 0; i < row.length; i++) {
                            row[i] = row[i].trim();
                        }
                        data.add(row);
                    }
                }
            }
        } catch (IOException e1) {
            e1.printStackTrace();
        } finally {
            // BUGFIX: close in finally so the reader is not leaked on error.
            if (br != null) {
                try {
                    br.close();
                } catch (IOException ignored) {
                    // best effort; nothing sensible to do on close failure
                }
            }
        }
    }

    /**
     * Sets the class column by index; exits the JVM on an invalid index
     * (original behaviour preserved).
     */
    public void setDec(int n) {
        if (n < 0 || n >= trainAttribute.size()) {
            System.err.println("決策變量指定錯誤。");
            System.exit(2);
        }
        decatt = n;
    }

    /** Sets the class column by attribute name. */
    public void setDec(String name) {
        int n = trainAttribute.indexOf(name);
        setDec(n);
    }

    /**
     * Entropy of a class-count histogram; the total is computed here.
     *
     * @param arr count of samples per class
     */
    public double getEntropy(int[] arr) {
        int sum = 0;
        for (int i = 0; i < arr.length; i++) {
            sum += arr[i];
        }
        return getEntropy(arr, sum);
    }

    /**
     * Entropy of a class-count histogram given its precomputed total.
     * Uses -sum(p*log2 p) rewritten as (sum*log2 sum - sum(c*log2 c))/sum;
     * the Double.MIN_VALUE offset keeps log(0) finite for zero counts.
     *
     * @param arr class counts
     * @param sum total number of samples in arr
     */
    public double getEntropy(int[] arr, int sum) {
        // BUGFIX: an empty partition previously produced 0/0 = NaN, which
        // poisoned the attribute-selection comparison in buildDT.
        if (sum == 0) {
            return 0.0;
        }
        double entropy = 0.0;
        for (int i = 0; i < arr.length; i++) {
            entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE) / Math.log(2);
        }
        entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    /**
     * True when every row in the (non-empty) subset carries the same class
     * label — such a subset becomes a leaf.
     */
    public boolean infoPure(ArrayList<Integer> subset) {
        String value = train_data.get(subset.get(0))[decatt];
        for (int i = 1; i < subset.size(); i++) {
            String next = train_data.get(subset.get(i))[decatt];
            if (!value.trim().equals(next.trim()))
                return false;
        }
        return true;
    }

    /**
     * Conditional entropy H(class | attribute) when splitting the given rows
     * on the attribute at column {@code index}.
     *
     * @param subset row indices into train_data
     * @param index  candidate attribute column
     */
    public double calNodeEntropy(ArrayList<Integer> subset, int index) {
        int sum = subset.size();
        double entropy = 0.0;
        // info[v][c] = number of rows with attribute value v and class c
        int[][] info = new int[train_attributeValue.get(index).size()][];
        for (int i = 0; i < info.length; i++)
            info[i] = new int[train_attributeValue.get(decatt).size()];
        int[] count = new int[train_attributeValue.get(index).size()]; // rows per attribute value
        for (int i = 0; i < sum; i++) {
            int n = subset.get(i);
            String nodevalue = train_data.get(n)[index].trim();
            int nodeind = train_attributeValue.get(index).indexOf(nodevalue);
            count[nodeind]++;
            String decvalue = train_data.get(n)[decatt].trim();
            int decind = train_attributeValue.get(decatt).indexOf(decvalue);
            info[nodeind][decind]++;
        }
        // Weight each partition's entropy by its share of the rows.
        for (int i = 0; i < info.length; i++) {
            entropy += getEntropy(info[i]) * count[i] / sum;
        }
        return entropy;
    }

    /**
     * Builds the decision tree (core routine). Locates the XML element named
     * {@code node} whose "value" attribute equals {@code value}, then grows
     * the subtree for the given rows and candidate attributes.
     *
     * @param node   element (attribute) name of the current tree node
     * @param value  branch value leading to this node
     * @param subset training-row indices reaching this node
     * @param selatt attribute indices still available for splitting
     */
    public void buildDT(String node, String value, ArrayList<Integer> subset,
            LinkedList<Integer> selatt) {
        Element ele = null;
        @SuppressWarnings("unchecked")
        List<Element> list = root.selectNodes("//" + node);
        Iterator<Element> iter = list.iterator();
        while (iter.hasNext()) {
            ele = iter.next();
            if (ele.attributeValue("value").equals(value))
                break;
        }
        buildDT(ele, subset, selatt);
    }

    /**
     * Recursive worker operating on the XML element directly. Recursing on
     * the element (instead of re-finding it by the global XPath "//name")
     * fixes two defects of the original: children could be attached to the
     * wrong node when the same attribute/value pair occurred in two
     * subtrees, and the shared {@code selatt} list was mutated across
     * sibling branches so siblings lost attributes chosen elsewhere.
     */
    private void buildDT(Element ele, ArrayList<Integer> subset, LinkedList<Integer> selatt) {
        if (subset.isEmpty()) {
            return; // defensive: empty branches are labelled by the caller
        }
        if (infoPure(subset)) {
            ele.setText(train_data.get(subset.get(0))[decatt]); // single class: write the label
            return;
        }
        if (selatt.isEmpty()) {
            ele.setText(majorityLabel(subset)); // attributes exhausted: majority vote
            return;
        }
        // Pick the attribute with the lowest conditional entropy (= highest gain).
        int minIndex = -1;
        double minEntropy = Double.MAX_VALUE;
        for (int attrIdx : selatt) {
            // BUGFIX: the original compared the loop *position* i with decatt
            // instead of the attribute index itself.
            if (attrIdx == decatt)
                continue;
            double entropy = calNodeEntropy(subset, attrIdx);
            if (entropy < minEntropy) {
                minIndex = attrIdx;
                minEntropy = entropy;
            }
        }
        String nodeName = trainAttribute.get(minIndex);
        // Work on a copy so each branch keeps its own pool of attributes.
        LinkedList<Integer> remaining = new LinkedList<Integer>(selatt);
        remaining.remove(Integer.valueOf(minIndex));
        ArrayList<String> attvalues = train_attributeValue.get(minIndex);
        for (String val : attvalues) {
            Element child = ele.addElement(nodeName).addAttribute("value", val);
            ArrayList<Integer> al = new ArrayList<Integer>();
            for (int i = 0; i < subset.size(); i++) {
                if (train_data.get(subset.get(i))[minIndex].equals(val)) {
                    al.add(subset.get(i));
                }
            }
            if (al.isEmpty()) {
                // No training row has this value here: label with the parent's majority
                // class instead of crashing in infoPure on an empty subset.
                child.setText(majorityLabel(subset));
            } else {
                buildDT(child, al, new LinkedList<Integer>(remaining));
            }
        }
    }

    /** Most frequent class label among the given training rows. */
    private String majorityLabel(ArrayList<Integer> subset) {
        ArrayList<String> labels = train_attributeValue.get(decatt);
        int[] counts = new int[labels.size()];
        for (int row : subset) {
            int idx = labels.indexOf(train_data.get(row)[decatt].trim());
            if (idx >= 0) {
                counts[idx]++;
            }
        }
        int best = 0;
        for (int i = 1; i < counts.length; i++) {
            if (counts[i] > counts[best]) {
                best = i;
            }
        }
        return labels.get(best);
    }

    /**
     * Writes the decision-tree XML document to a file.
     *
     * @param filename output path
     */
    public void writeXML(String filename) {
        try {
            File file = new File(filename);
            if (!file.exists())
                file.createNewFile();
            FileWriter fw = new FileWriter(file);
            OutputFormat format = OutputFormat.createPrettyPrint(); // pretty-printed output
            XMLWriter output = new XMLWriter(fw, format);
            output.write(xmldoc);
            output.close(); // also closes the underlying FileWriter
        } catch (IOException e) {
            System.out.println(e.getMessage());
        }
    }
}

主函數

package ID3;


public class Main {
    /**
     * Demo driver: trains an ID3 tree on the weather data, dumps the tree
     * as XML, then prints the predicted label for each test row.
     */
    public static void main(String[] args) {
        ID3 model = new ID3();
        model.train("play", "files/ID3/train.arff");
        model.writeXML("files/ID3/ID3_Tree.xml");
        String[] labels = model.predict("files/ID3/predict.arff");
        int row = 0;
        for (String label : labels) {
            // Same output as the original: row index immediately followed by the label.
            System.out.println(row + label);
            row++;
        }
    }
}

決策樹xml文件

<?xml version="1.0" encoding="UTF-8"?>

<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章