歡迎到我的個人域名博客:http://zhoulingyu.com
這學期hadoop的作業是實現ID3算法,在網上找到了一篇非常好的資料,但是代碼沒有詳細介紹。研究一番之後寫出了自己ID3算法的普通實現和線程模擬分佈式實現。
先附上原文地址:
http://www.cnblogs.com/zhangchaoyang/articles/2196631.html
博主寫的非常好,基本上簡單易懂的描述了ID3算法的原理。
問題
統計了14天的氣象數據(指標包括outlook,temperature,humidity,windy),並已知這些天氣是否打球(play)。如果給出新一天的氣象指標數據:sunny,cool,high,TRUE,判斷一下會不會去打球。
ID3原理
直接看博客鏈接,非常詳細http://www.cnblogs.com/zhangchaoyang/articles/2196631.html
代碼實現
別忘了放入輸入文件,按我的程序保存在指定位置,保存名爲weather.arff
@relation weather.symbolic
@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
package com.coderfish.id3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
/**
* date:2015年6月15日<br/ >
* description:
* 統計了14天的氣象數據(outlook天氣狀況,temperature溫度,humidity溼度,windy風力),並已知這些天氣是否打球(play)
* 如果給出新一天的氣象指標數據:sunny,cool,high,TRUE,判斷一下會不會去打球。<br/ >
*
* @author 周凌宇 <br/ >
*
*/
public class ID3 {
/**
* 存儲屬性的名稱
*/
private ArrayList<String> attribute = new ArrayList<String>();
/**
* 存儲每個屬性的取值
*/
private ArrayList<ArrayList<String>> attributeValue = new ArrayList<ArrayList<String>>();
/**
* 原始數據
*/
private ArrayList<String[]> data = new ArrayList<String[]>();
/**
* 決策變量在屬性集中的索引
*/
int decisionIndex;
/**
* ARFF格式
*/
public static final String patternString = "@attribute(.*)[{](.*?)[}]";
/**
* XML文件對象
*/
Document xmldoc;
/**
* 根節點
*/
Element root;
/**
* 構造方法<br/ >
* 新建xml<br/ >
* <root><br/ >
* <DecisionTree value="null"><br/ >
* </DecisionTree><br/ >
* </root>
*/
public ID3() {
xmldoc = DocumentHelper.createDocument();
root = xmldoc.addElement("root");
root.addElement("DecisionTree").addAttribute("value", "null");
}
public static void main(String[] args) {
ID3 inst = new ID3();
inst.readARFF(new File("f:\\id3temp\\weather.arff"));
inst.setDec("play");
// conditionList存放條件個數
LinkedList<Integer> conditionNumList = new LinkedList<Integer>();
for (int i = 0; i < inst.attribute.size(); i++) {
// 如果i!=決策變量索引(實際5)
if (i != inst.decisionIndex)
conditionNumList.add(i);
}
// dataList存放data的個數
ArrayList<Integer> dataNumList = new ArrayList<Integer>();
for (int i = 0; i < inst.data.size(); i++) {
dataNumList.add(i);
}
inst.buildDT("DecisionTree", "null", dataNumList, conditionNumList);
inst.writeXML("f:\\id3temp\\dt.xml");
return;
}
/**
* 讀取arff文件,給attribute、attributevalue、data賦值
*
* @param file
* 要讀取的文件
*/
public void readARFF(File file) {
try {
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
// 設置arff模式
Pattern pattern = Pattern.compile(patternString);
// 讀取數據
while ((line = br.readLine()) != null) {
// 判斷合規
Matcher matcher = pattern.matcher(line);
// 嘗試查找與該模式匹配的輸入序列的下一個子序列
if (matcher.find()) {
// 返回匹配到的子字符串
attribute.add(matcher.group(1).trim());
String[] values = matcher.group(2).split(",");
ArrayList<String> al = new ArrayList<String>(values.length);
for (String value : values) {
al.add(value.trim());
}
// 存入attributevalue對象中
attributeValue.add(al);
}
// 判斷開頭
else if (line.startsWith("@data")) {
while ((line = br.readLine()) != null) {
if (line == "")
continue;
String[] row = line.split(",");
data.add(row);
}
} else {
continue;
}
}
br.close();
} catch (IOException e1) {
e1.printStackTrace();
}
}
/**
* 設置決策變量索引
*
* @param n
*/
public void setDec(int n) {
if (n < 0 || n >= attribute.size()) {
System.err.println("決策變量指定錯誤。");
System.exit(2);
}
decisionIndex = n;
}
/**
* 設置決策變量
*
* @param name
*/
public void setDec(String name) {
// 獲取下標
int n = attribute.indexOf(name);
setDec(n);
}
/**
* 計算熵
*
* @param arr
* yes or no個數
* @return 熵
*/
public double getEntropy(int[] arr) {
double entropy = 0.0;
int sum = 0;
for (int i = 0; i < arr.length; i++) {
entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)
/ Math.log(2);
sum += arr[i];
}
// log(N)/log(2)
entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
entropy /= sum;
return entropy;
}
/**
* 判斷節點值是否全部相同
*
* @param subset
* 給定子集
* @return 全部相同返回true,有不相同返回false
*/
public boolean infoPure(ArrayList<Integer> subset) {
// yes or no
String value = data.get(subset.get(0))[decisionIndex];
for (int i = 1; i < subset.size(); i++) {
String next = data.get(subset.get(i))[decisionIndex];
// 如果該結點值和第一個節點值不同
if (!value.equals(next))
return false;
}
return true;
}
/**
* 給定原始數據的子集(subset中存儲行號),當以第index個屬性爲節點時計算它的信息熵
*
* @param subset
* 給定子集
* @param index
* 屬性下標
* @return
*/
public double calNodeEntropy(ArrayList<Integer> subset, int index) {
int sum = subset.size();
double entropy = 0.0;
// 長度爲屬性個數 存放每種可能的決策情況
int[][] info = new int[attributeValue.get(index).size()][];
for (int i = 0; i < info.length; i++)
// 決策個數
info[i] = new int[attributeValue.get(decisionIndex).size()];
// 長度爲屬性個數
int[] count = new int[attributeValue.get(index).size()];
for (int i = 0; i < sum; i++) {
// 子集遍歷下標
int n = subset.get(i);
String nodeValue = data.get(n)[index];
int nodeIndex = attributeValue.get(index).indexOf(nodeValue);
count[nodeIndex]++;
// 決策結果
String decValue = data.get(n)[decisionIndex];
int decIndex = attributeValue.get(decisionIndex).indexOf(decValue);
info[nodeIndex][decIndex]++;
}
for (int i = 0; i < info.length; i++) {
// 加權求值
entropy += getEntropy(info[i]) * count[i] / sum;
}
return entropy;
}
/**
* 構建決策樹 第一次給定值爲"DecisionTree", "null", al, ll
*
* @param name
* 節點名稱
* @param value
* 節點值
* @param subset
* 子集
* @param conditionList
*/
public void buildDT(String name, String value,
ArrayList<Integer> dataNumList, LinkedList<Integer> conditionNumList) {
Element ele = null;
// 從任意位置的節點上選擇名稱爲 name 的節點。
List<Element> list = root.selectNodes("//" + name);
Iterator<Element> iter = list.iterator();
while (iter.hasNext()) {
ele = iter.next();
// 節點匹配value跳出循環
if (ele.attributeValue("value").equals(value))
break;
}
// 如果子決策全部相同 直接放入決策 跳出遞歸
if (infoPure(dataNumList)) {
ele.setText(data.get(dataNumList.get(0))[decisionIndex]);
return;
}
int minIndex = -1;
double minEntropy = Double.MAX_VALUE;
for (int i = 0; i < conditionNumList.size(); i++) {
if (i == decisionIndex)
continue;
// 給定子集計算熵
double entropy = calNodeEntropy(dataNumList,
conditionNumList.get(i));
if (entropy < minEntropy) {
minIndex = conditionNumList.get(i);
minEntropy = entropy;
}
}
// 找出最小熵的節點屬性名(作爲要加入節點名)
String nodeName = attribute.get(minIndex);
conditionNumList.remove(new Integer(minIndex));
// 拿出對應屬性的所有值
ArrayList<String> attvalues = attributeValue.get(minIndex);
for (String val : attvalues) {
ele.addElement(nodeName).addAttribute("value", val);
ArrayList<Integer> al = new ArrayList<Integer>();
for (int i = 0; i < dataNumList.size(); i++) {
// 找出開頭爲val的數據行
if (data.get(dataNumList.get(i))[minIndex].equals(val)) {
al.add(dataNumList.get(i));
}
}
buildDT(nodeName, val, al, conditionNumList);
}
}
/**
* xml輸出
*
* @param filename
*/
public void writeXML(String filename) {
try {
File file = new File(filename);
if (!file.exists())
file.createNewFile();
FileWriter fw = new FileWriter(file);
OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式
XMLWriter output = new XMLWriter(fw, format);
output.write(xmldoc);
output.close();
} catch (IOException e) {
System.out.println(e.getMessage());
}
}
}