ID3算法詳解

歡迎到我的個人域名博客:http://zhoulingyu.com


這學期hadoop的作業是實現ID3算法,在網上找到了一篇非常好的資料,但是代碼沒有詳細介紹。研究一番之後寫出了自己ID3算法的普通實現和線程模擬分佈式實現。

先附上原文地址:
http://www.cnblogs.com/zhangchaoyang/articles/2196631.html
博主寫的非常好,基本上簡單易懂的描述了ID3算法的原理。

問題

統計了14天的氣象數據(指標包括outlook,temperature,humidity,windy),並已知這些天氣是否打球(play)。如果給出新一天的氣象指標數據:sunny,cool,high,TRUE,判斷一下會不會去打球。

ID3原理

直接看博客鏈接,非常詳細http://www.cnblogs.com/zhangchaoyang/articles/2196631.html

代碼實現

別忘了放入輸入文件,按我的程序保存在指定位置,保存名爲weather.arff

@relation weather.symbolic

@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}

@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
package com.coderfish.id3;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;

/**
 * date:2015年6月15日<br/ >
 * description:
 * 統計了14天的氣象數據(outlook天氣狀況,temperature溫度,humidity溼度,windy風力),並已知這些天氣是否打球(play)
 * 如果給出新一天的氣象指標數據:sunny,cool,high,TRUE,判斷一下會不會去打球。<br/ >
 * 
 * @author 周凌宇 <br/ >
 * 
 */
public class ID3 {
    /**
     * 存儲屬性的名稱
     */
    private ArrayList<String> attribute = new ArrayList<String>();
    /**
     * 存儲每個屬性的取值
     */
    private ArrayList<ArrayList<String>> attributeValue = new ArrayList<ArrayList<String>>();
    /**
     * 原始數據
     */
    private ArrayList<String[]> data = new ArrayList<String[]>();
    /**
     * 決策變量在屬性集中的索引
     */
    int decisionIndex;
    /**
     * ARFF格式
     */
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
    /**
     * XML文件對象
     */
    Document xmldoc;
    /**
     * 根節點
     */
    Element root;

    /**
     * 構造方法<br/ >
     * 新建xml<br/ >
     * <root><br/ >
     * <DecisionTree value="null"><br/ >
     * </DecisionTree><br/ >
     * </root>
     */
    public ID3() {
        xmldoc = DocumentHelper.createDocument();
        root = xmldoc.addElement("root");
        root.addElement("DecisionTree").addAttribute("value", "null");
    }

    public static void main(String[] args) {
        ID3 inst = new ID3();
        inst.readARFF(new File("f:\\id3temp\\weather.arff"));
        inst.setDec("play");
        // conditionList存放條件個數
        LinkedList<Integer> conditionNumList = new LinkedList<Integer>();
        for (int i = 0; i < inst.attribute.size(); i++) {
            // 如果i!=決策變量索引(實際5)
            if (i != inst.decisionIndex)
                conditionNumList.add(i);
        }
        // dataList存放data的個數
        ArrayList<Integer> dataNumList = new ArrayList<Integer>();
        for (int i = 0; i < inst.data.size(); i++) {
            dataNumList.add(i);
        }
        inst.buildDT("DecisionTree", "null", dataNumList, conditionNumList);
        inst.writeXML("f:\\id3temp\\dt.xml");
        return;
    }

    /**
     * 讀取arff文件,給attribute、attributevalue、data賦值
     * 
     * @param file
     *            要讀取的文件
     */
    public void readARFF(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            // 設置arff模式
            Pattern pattern = Pattern.compile(patternString);
            // 讀取數據
            while ((line = br.readLine()) != null) {
                // 判斷合規
                Matcher matcher = pattern.matcher(line);
                // 嘗試查找與該模式匹配的輸入序列的下一個子序列
                if (matcher.find()) {
                    // 返回匹配到的子字符串
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    // 存入attributevalue對象中
                    attributeValue.add(al);
                }
                // 判斷開頭
                else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if (line == "")
                            continue;
                        String[] row = line.split(",");
                        data.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

    /**
     * 設置決策變量索引
     * 
     * @param n
     */
    public void setDec(int n) {
        if (n < 0 || n >= attribute.size()) {
            System.err.println("決策變量指定錯誤。");
            System.exit(2);
        }
        decisionIndex = n;
    }

    /**
     * 設置決策變量
     * 
     * @param name
     */
    public void setDec(String name) {
        // 獲取下標
        int n = attribute.indexOf(name);
        setDec(n);
    }

    /**
     * 計算熵
     * 
     * @param arr
     *            yes or no個數
     * @return 熵
     */
    public double getEntropy(int[] arr) {
        double entropy = 0.0;
        int sum = 0;
        for (int i = 0; i < arr.length; i++) {
            entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)
                    / Math.log(2);
            sum += arr[i];
        }
        // log(N)/log(2)
        entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    /**
     * 判斷節點值是否全部相同
     * 
     * @param subset
     *            給定子集
     * @return 全部相同返回true,有不相同返回false
     */
    public boolean infoPure(ArrayList<Integer> subset) {
        // yes or no
        String value = data.get(subset.get(0))[decisionIndex];
        for (int i = 1; i < subset.size(); i++) {
            String next = data.get(subset.get(i))[decisionIndex];
            // 如果該結點值和第一個節點值不同
            if (!value.equals(next))
                return false;
        }
        return true;
    }

    /**
     * 給定原始數據的子集(subset中存儲行號),當以第index個屬性爲節點時計算它的信息熵
     * 
     * @param subset
     *            給定子集
     * @param index
     *            屬性下標
     * @return
     */
    public double calNodeEntropy(ArrayList<Integer> subset, int index) {
        int sum = subset.size();
        double entropy = 0.0;
        // 長度爲屬性個數 存放每種可能的決策情況
        int[][] info = new int[attributeValue.get(index).size()][];
        for (int i = 0; i < info.length; i++)
            // 決策個數
            info[i] = new int[attributeValue.get(decisionIndex).size()];
        // 長度爲屬性個數
        int[] count = new int[attributeValue.get(index).size()];
        for (int i = 0; i < sum; i++) {
            // 子集遍歷下標
            int n = subset.get(i);
            String nodeValue = data.get(n)[index];
            int nodeIndex = attributeValue.get(index).indexOf(nodeValue);
            count[nodeIndex]++;
            // 決策結果
            String decValue = data.get(n)[decisionIndex];
            int decIndex = attributeValue.get(decisionIndex).indexOf(decValue);
            info[nodeIndex][decIndex]++;
        }
        for (int i = 0; i < info.length; i++) {
            // 加權求值
            entropy += getEntropy(info[i]) * count[i] / sum;
        }
        return entropy;
    }

    /**
     * 構建決策樹 第一次給定值爲"DecisionTree", "null", al, ll
     * 
     * @param name
     *            節點名稱
     * @param value
     *            節點值
     * @param subset
     *            子集
     * @param conditionList
     */
    public void buildDT(String name, String value,
            ArrayList<Integer> dataNumList, LinkedList<Integer> conditionNumList) {
        Element ele = null;
        // 從任意位置的節點上選擇名稱爲 name 的節點。
        List<Element> list = root.selectNodes("//" + name);
        Iterator<Element> iter = list.iterator();
        while (iter.hasNext()) {
            ele = iter.next();
            // 節點匹配value跳出循環
            if (ele.attributeValue("value").equals(value))
                break;
        }
        // 如果子決策全部相同 直接放入決策 跳出遞歸
        if (infoPure(dataNumList)) {
            ele.setText(data.get(dataNumList.get(0))[decisionIndex]);
            return;
        }
        int minIndex = -1;
        double minEntropy = Double.MAX_VALUE;
        for (int i = 0; i < conditionNumList.size(); i++) {
            if (i == decisionIndex)
                continue;
            // 給定子集計算熵
            double entropy = calNodeEntropy(dataNumList,
                    conditionNumList.get(i));
            if (entropy < minEntropy) {
                minIndex = conditionNumList.get(i);
                minEntropy = entropy;
            }
        }
        // 找出最小熵的節點屬性名(作爲要加入節點名)
        String nodeName = attribute.get(minIndex);
        conditionNumList.remove(new Integer(minIndex));
        // 拿出對應屬性的所有值
        ArrayList<String> attvalues = attributeValue.get(minIndex);
        for (String val : attvalues) {
            ele.addElement(nodeName).addAttribute("value", val);
            ArrayList<Integer> al = new ArrayList<Integer>();
            for (int i = 0; i < dataNumList.size(); i++) {
                // 找出開頭爲val的數據行
                if (data.get(dataNumList.get(i))[minIndex].equals(val)) {
                    al.add(dataNumList.get(i));
                }
            }
            buildDT(nodeName, val, al, conditionNumList);
        }
    }

    /**
     * xml輸出
     * 
     * @param filename
     */
    public void writeXML(String filename) {
        try {
            File file = new File(filename);
            if (!file.exists())
                file.createNewFile();
            FileWriter fw = new FileWriter(file);
            OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式
            XMLWriter output = new XMLWriter(fw, format);
            output.write(xmldoc);
            output.close();
        } catch (IOException e) {
            System.out.println(e.getMessage());
        }
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章