ID3算法

先上問題吧,我們統計了14天的氣象數據(指標包括outlook,temperature,humidity,windy),並已知這些天氣是否打球(play)。如果給出新一天的氣象指標數據:sunny,cool,high,TRUE,判斷一下會不會去打球。

outlook temperature humidity windy play
sunny hot high false no
sunny hot high true no
overcast hot high false yes
rainy mild high false yes
rainy cool normal false yes
rainy cool normal true no
overcast cool normal true yes
sunny mild high false no
sunny cool normal false yes
rainy mild normal false yes
sunny mild normal true yes
overcast mild high true yes
overcast hot normal false yes
rainy mild high true no

 

這個問題當然可以用樸素貝葉斯法求解,分別計算在給定天氣條件下打球和不打球的概率,選概率大者作爲推測結果。

現在我們使用ID3歸納決策樹的方法來求解該問題。

預備知識:信息熵

熵是無序性(或不確定性)的度量指標。假如事件A的全概率劃分是(A1,A2,...,An),每部分發生的概率是(p1,p2,...,pn),那信息熵定義爲:

通常以2爲底數,所以信息熵的單位是bit。

補充兩個對數去處公式:

ID3算法

構造樹的基本想法是隨着樹深度的增加,節點的熵迅速地降低。熵降低的速度越快越好,這樣我們有望得到一棵高度最矮的決策樹。

在沒有給定任何天氣信息時,根據歷史數據,我們只知道新的一天打球的概率是9/14,不打的概率是5/14。此時的熵爲:

屬性有4個:outlook,temperature,humidity,windy。我們首先要決定哪個屬性作樹的根節點。

對每項指標分別統計:在不同的取值下打球和不打球的次數。

下面我們計算當已知變量outlook的值時,信息熵爲多少。

outlook=sunny時,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast時,entropy=0

outlook=rainy時,entropy=0.971

而根據歷史統計數據,outlook取值爲sunny、overcast、rainy的概率分別是5/14、4/14、5/14,所以當已知變量outlook的值時,信息熵爲:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

這樣的話系統熵就從0.940下降到了0.693,信息增溢gain(outlook)爲0.940-0.693=0.247

同樣可以計算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。

gain(outlook)最大(即outlook在第一步使系統的信息熵下降得最快),所以決策樹的根節點就取outlook。

接下來要確定N1取temperature、humidity還是windy?在已知outlook=sunny的情況,根據歷史數據,我們作出類似table 2的一張表,分別計算gain(temperature)、gain(humidity)和gain(windy),選最大者爲N1。

依此類推,構造決策樹。當系統的信息熵降爲0時,就沒有必要再往下構造決策樹了,此時葉子節點都是純的--這是理想情況。最壞的情況下,決策樹的高度爲屬性(決策變量)的個數,葉子節點不純(這意味着我們要以一定的概率來作出決策)。

Java實現

最終的決策樹保存在了XML中,使用了Dom4J,注意如果要讓Dom4J支持按XPath選擇節點,還得引入包jaxen.jar。程序代碼要求輸入文件滿足ARFF格式,並且屬性都是標稱變量。

實驗用的數據文件:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
@relation weather.symbolic 
@attribute outlook {sunny, overcast, rainy} 
@attribute temperature {hot, mild, cool} 
@attribute humidity {high, normal} 
@attribute windy {TRUE, FALSE} 
@attribute play {yes, no} 
   
@data 
sunny,hot,high,FALSE,no 
sunny,hot,high,TRUE,no 
overcast,hot,high,FALSE,yes 
rainy,mild,high,FALSE,yes 
rainy,cool,normal,FALSE,yes 
rainy,cool,normal,TRUE,no 
overcast,cool,normal,TRUE,yes 
sunny,mild,high,FALSE,no 
sunny,cool,normal,FALSE,yes 
rainy,mild,normal,FALSE,yes 
sunny,mild,normal,TRUE,yes 
overcast,mild,high,TRUE,yes 
overcast,hot,normal,FALSE,yes 
rainy,mild,high,TRUE,no

程序代碼:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
package com.dfsj;
import java.io.BufferedReader; 
import java.io.File; 
import java.io.FileReader; 
import java.io.FileWriter; 
import java.io.IOException; 
import java.util.ArrayList; 
import java.util.Iterator; 
import java.util.LinkedList; 
import java.util.List; 
import java.util.regex.Matcher; 
import java.util.regex.Pattern; 
   
import org.dom4j.Document; 
import org.dom4j.DocumentHelper; 
import org.dom4j.Element; 
import org.dom4j.io.OutputFormat; 
import org.dom4j.io.XMLWriter; 
   
public class ID3 { 
    private ArrayList<String> attribute = new ArrayList<String>(); // 存儲屬性的名稱 
    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存儲每個屬性的取值 
    private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始數據 
    int decatt; // 決策變量在屬性集中的索引 
    public static final String patternString = "@attribute(.*)[{](.*?)[}]"
   
    Document xmldoc; 
    Element root; 
   
    public ID3() { 
        xmldoc = DocumentHelper.createDocument(); 
        root = xmldoc.addElement("root"); 
        root.addElement("DecisionTree").addAttribute("value""null"); 
    
   
    public static void main(String[] args) { 
        ID3 inst = new ID3(); 
        inst.readARFF(new File("/home/orisun/test/weather.nominal.arff")); 
        inst.setDec("play"); 
        LinkedList<Integer> ll=new LinkedList<Integer>(); 
        for(int i=0;i<inst.attribute.size();i++){ 
            if(i!=inst.decatt) 
                ll.add(i); 
        
        ArrayList<Integer> al=new ArrayList<Integer>(); 
        for(int i=0;i<inst.data.size();i++){ 
            al.add(i); 
        
        inst.buildDT("DecisionTree""null", al, ll); 
        inst.writeXML("/home/orisun/test/dt.xml"); 
        return
    
   
    //讀取arff文件,給attribute、attributevalue、data賦值 
    public void readARFF(File file) { 
        try 
            FileReader fr = new FileReader(file); 
            BufferedReader br = new BufferedReader(fr); 
            String line; 
            Pattern pattern = Pattern.compile(patternString); 
            while ((line = br.readLine()) != null) { 
                Matcher matcher = pattern.matcher(line); 
                if (matcher.find()) { 
                    attribute.add(matcher.group(1).trim()); 
                    String[] values = matcher.group(2).split(","); 
                    ArrayList<String> al = new ArrayList<String>(values.length); 
                    for (String value : values) { 
                        al.add(value.trim()); 
                    
                    attributevalue.add(al); 
                else if (line.startsWith("@data")) { 
                    while ((line = br.readLine()) != null) { 
                        if(line==""
                            continue
                        String[] row = line.split(","); 
                        data.add(row); 
                    
                else 
                    continue
                
            
            br.close(); 
        catch (IOException e1) { 
            e1.printStackTrace(); 
        
    
   
    //設置決策變量 
    public void setDec(int n) { 
        if (n < 0 || n >= attribute.size()) { 
            System.err.println("決策變量指定錯誤。"); 
            System.exit(2); 
        
        decatt = n; 
    
    public void setDec(String name) { 
        int n = attribute.indexOf(name); 
        setDec(n); 
    
   
    //給一個樣本(數組中是各種情況的計數),計算它的熵 
    public double getEntropy(int[] arr) { 
        double entropy = 0.0
        int sum = 0
        for (int i = 0; i < arr.length; i++) { 
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2); 
            sum += arr[i]; 
        
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2); 
        entropy /= sum; 
        return entropy; 
    
   
    //給一個樣本數組及樣本的算術和,計算它的熵 
    public double getEntropy(int[] arr, int sum) { 
        double entropy = 0.0
        for (int i = 0; i < arr.length; i++) { 
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2); 
        
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2); 
        entropy /= sum; 
        return entropy; 
    
   
    public boolean infoPure(ArrayList<Integer> subset) { 
        String value = data.get(subset.get(0))[decatt]; 
        for (int i = 1; i < subset.size(); i++) { 
            String next=data.get(subset.get(i))[decatt]; 
            //equals表示對象內容相同,==表示兩個對象指向的是同一片內存 
            if (!value.equals(next)) 
                return false
        
        return true
    
   
    // 給定原始數據的子集(subset中存儲行號),當以第index個屬性爲節點時計算它的信息熵 
    public double calNodeEntropy(ArrayList<Integer> subset, int index) { 
        int sum = subset.size(); 
        double entropy = 0.0
        int[][] info = new int[attributevalue.get(index).size()][]; 
        for (int i = 0; i < info.length; i++) 
            info[i] = new int[attributevalue.get(decatt).size()]; 
        int[] count = new int[attributevalue.get(index).size()]; 
        for (int i = 0; i < sum; i++) { 
            int n = subset.get(i); 
            String nodevalue = data.get(n)[index]; 
            int nodeind = attributevalue.get(index).indexOf(nodevalue); 
            count[nodeind]++; 
            String decvalue = data.get(n)[decatt]; 
            int decind = attributevalue.get(decatt).indexOf(decvalue); 
            info[nodeind][decind]++; 
        
        for (int i = 0; i < info.length; i++) { 
            entropy += getEntropy(info[i]) * count[i] / sum; 
        
        return entropy; 
    
   
    // 構建決策樹 
    public void buildDT(String name, String value, ArrayList<Integer> subset, 
            LinkedList<Integer> selatt) { 
        Element ele = null
        @SuppressWarnings("unchecked"
        List<Element> list = root.selectNodes("//"+name); 
        Iterator<Element> iter=list.iterator(); 
        while(iter.hasNext()){ 
            ele=iter.next(); 
            if(ele.attributeValue("value").equals(value)) 
                break
        
        if (infoPure(subset)) { 
            ele.setText(data.get(subset.get(0))[decatt]); 
            return
        
        int minIndex = -1
        double minEntropy = Double.MAX_VALUE; 
        for (int i = 0; i < selatt.size(); i++) { 
            if (i == decatt) 
                continue
            double entropy = calNodeEntropy(subset, selatt.get(i)); 
            if (entropy < minEntropy) { 
                minIndex = selatt.get(i); 
                minEntropy = entropy; 
            
        
        String nodeName = attribute.get(minIndex); 
        selatt.remove(new Integer(minIndex)); 
        ArrayList<String> attvalues = attributevalue.get(minIndex); 
        for (String val : attvalues) { 
            ele.addElement(nodeName).addAttribute("value", val); 
            ArrayList<Integer> al = new ArrayList<Integer>(); 
            for (int i = 0; i < subset.size(); i++) { 
                if (data.get(subset.get(i))[minIndex].equals(val)) { 
                    al.add(subset.get(i)); 
                
            
            buildDT(nodeName, val, al, selatt); 
        
    
   
    // 把xml寫入文件 
    public void writeXML(String filename) { 
        try 
            File file = new File(filename); 
            if (!file.exists()) 
                file.createNewFile(); 
            FileWriter fw = new FileWriter(file); 
            OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式 
            XMLWriter output = new XMLWriter(fw, format); 
            output.write(xmldoc); 
            output.close(); 
        catch (IOException e) { 
            System.out.println(e.getMessage()); 
        
    
}

 

最終生成的文件如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
<?xml version="1.0" encoding="UTF-8"?> 
   
<root> 
  <DecisionTree value="null"
    <outlook value="sunny"
      <humidity value="high">no</humidity> 
      <humidity value="normal">yes</humidity> 
    </outlook> 
    <outlook value="overcast">yes</outlook> 
    <outlook value="rainy"
      <windy value="TRUE">no</windy> 
      <windy value="FALSE">yes</windy> 
    </outlook> 
  </DecisionTree> 
</root>

 

用圖形象地表示就是:

 

 原文地址:http://my.oschina.net/dfsj66011/blog/343647?fromerr=3XqeFN8R

發佈了1 篇原創文章 · 獲贊 8 · 訪問量 7萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章