A Hand-Written Decision Tree with Visualization


Description

The data used is the Lenses Data Set from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/Lenses).

It contains:
24 instances
3 classes:
1: the patient should be fitted with hard contact lenses,
2: the patient should be fitted with soft contact lenses,
3: the patient should not be fitted with contact lenses.

4 attributes:
1: age of the patient: (1) young, (2) pre-presbyopic, (3) presbyopic
2: spectacle prescription: (1) myope, (2) hypermetrope
3: astigmatic: (1) no, (2) yes
4: tear production rate: (1) reduced, (2) normal

Raw data format (each row holds the instance id, the four attribute values, and the class label):

1 1 1 1 1 3
2 1 1 1 2 2
3 1 1 2 1 3

20 3 1 2 2 1
21 3 2 1 1 3
22 3 2 1 2 2
23 3 2 2 1 3
24 3 2 2 2 3
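
As a small illustration of this layout, the sketch below parses a few of the rows shown above into attribute lists and class labels. The helper name parse_row is hypothetical and not part of the original code; it assumes whitespace-separated integer fields with the instance id first and the class label last.

```python
def parse_row(line):
    """Split one whitespace-separated record into (attributes, label).

    Assumes the layout shown above: instance id, four attributes, class label.
    """
    fields = [int(tok) for tok in line.split()]
    return fields[1:-1], fields[-1]  # drop the id, separate the label


# Three of the rows listed above.
sample = ["1 1 1 1 1 3", "2 1 1 1 2 2", "20 3 1 2 2 1"]
rows = [parse_row(line) for line in sample]
for attrs, label in rows:
    print(attrs, "->", label)
```

The appendix code keeps the label in the same list as the attributes (as the last element) rather than returning a pair; either representation carries the same information.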

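Theory

The decision tree is generated top-down by recursion.
Recursion stops at a node when either condition holds:

1. All instances in the node belong to the same class.
2. All instances in the node have identical attribute values.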
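
The two stopping conditions can be sketched as small predicates. This is a minimal sketch, assuming each instance is a list whose last element is the class label (the representation used in the appendix); the function names is_pure, same_attributes and majority_class are hypothetical. When the second condition holds but the classes still differ, a common choice is to label the leaf with the majority class.

```python
def is_pure(D):
    """Condition 1: every instance in D carries the same class label."""
    return len({row[-1] for row in D}) == 1


def same_attributes(D):
    """Condition 2: every instance in D has identical attribute values."""
    return all(row[:-1] == D[0][:-1] for row in D)


def majority_class(D):
    """Most frequent class label in D (used to label an impure leaf)."""
    labels = [row[-1] for row in D]
    return max(set(labels), key=labels.count)


# Identical attributes, mixed classes: no further split is possible.
D = [[1, 1, 3], [1, 1, 3], [1, 1, 2]]
print(is_pure(D), same_attributes(D), majority_class(D))
```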

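Main steps:
For each attribute, partition the node's instances by that attribute and compute the information gain of the split.
Select the attribute with the largest information gain and split the node on it.
Apply the same procedure recursively to each child node produced by the split.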
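
To make the selection step concrete, here is a minimal sketch of the usual definitions: the entropy of a node is Ent(D) = -sum_k p_k * log2(p_k) over the class proportions p_k, and the gain of an attribute is Ent(D) minus the size-weighted entropy of the subsets it produces. The function names and the four-row toy data set are assumptions made for illustration, not taken from the original code or the Lenses data.

```python
import math
from collections import Counter


def entropy(D):
    """Shannon entropy (in bits) of the class labels in the last column of D."""
    n = len(D)
    counts = Counter(row[-1] for row in D)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())


def info_gain(D, a):
    """Information gain obtained by partitioning D on attribute index a."""
    groups = {}
    for row in D:
        groups.setdefault(row[a], []).append(row)
    remainder = sum(len(g) / len(D) * entropy(g) for g in groups.values())
    return entropy(D) - remainder


# Toy data: attribute 0 determines the class, attribute 1 is uninformative.
D = [[1, 1, 1], [1, 2, 1], [2, 1, 2], [2, 2, 2]]
gains = [info_gain(D, a) for a in range(len(D[0]) - 1)]
best = max(range(len(gains)), key=gains.__getitem__)
print(gains, best)
```

In this toy case attribute 0 separates the two classes perfectly, so it receives the full one bit of gain and is chosen for the split, while attribute 1 contributes nothing.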

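Implementation

The implementation is written in Python.
A DT class holds the tree.
Each node is stored as a list: [id, class, samples, ent, attr, children], i.e. the node id, the class label (-1 for internal nodes), the number of instances in the node, the entropy of the node, the index of the attribute used for the split (-1 for leaves), and the children.
children is the list of child node ids; if it is empty, the node is a leaf.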
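
The sketch below shows what such a node list might look like for a tiny hand-made tree and how the empty children list identifies leaves. The three nodes are made up for illustration, and the names of the index constants and helper functions are hypothetical. The traversal assumes that a node's id equals its position in the list, which is how the appendix code assigns ids.

```python
# Positions of the fields inside one node list.
ID, CLASS, SAMPLES, ENT, ATTR, CHILDREN = range(6)

# A made-up tree: one root split on attribute 0 with two pure leaves.
nodes = [
    [0, -1, 4, 1.0, 0, [1, 2]],  # internal node: class -1, split on attribute 0
    [1, 1, 2, 0, -1, []],        # leaf with class 1
    [2, 2, 2, 0, -1, []],        # leaf with class 2
]


def is_leaf(node):
    """A node with no children is a leaf."""
    return len(node[CHILDREN]) == 0


def leaf_ids(nodes, root=0):
    """Collect the ids of all leaves below root by depth-first traversal."""
    node = nodes[root]
    if is_leaf(node):
        return [node[ID]]
    found = []
    for child in node[CHILDREN]:
        found.extend(leaf_ids(nodes, child))
    return found


print(leaf_ids(nodes))
```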

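Visualization
Visualization uses the DOT language: the tree is traversed and written to a .dot file in DOT syntax, and the file is then exported as an image.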
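
For a sense of what ends up in the .dot file, the following sketch turns a small made-up node list into DOT text. The function name to_dot and the sample nodes are hypothetical, and the label layout is simplified compared with the appendix code.

```python
def to_dot(nodes):
    """Return DOT source for nodes of the form [id, class, samples, ent, attr, children]."""
    lines = ["digraph dtnodes {"]
    for node_id, cls, samples, ent, attr, children in nodes:
        # "\\n" emits the two characters \n, which DOT renders as a line break.
        lines.append('  {} [label="class:{}\\nsamples:{}"];'.format(node_id, cls, samples))
    for node_id, cls, samples, ent, attr, children in nodes:
        for child in children:
            lines.append("  {} -> {};".format(node_id, child))
    lines.append("}")
    return "\n".join(lines)


# A made-up tree: one root with two leaves.
nodes = [
    [0, -1, 4, 1.0, 0, [1, 2]],
    [1, 1, 2, 0, -1, []],
    [2, 2, 2, 0, -1, []],
]
dot_text = to_dot(nodes)
print(dot_text)
```

If Graphviz is installed, a file containing this text can be rendered with a command such as dot -Tpng demo.dot -o demo.png.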

Appendix

import math
import numpy as np


class DT():
    def __init__(self):
        # Each node is a list: [id, class, samples, ent, attr, children]
        self.nodes = []
        self.node_index = []  # ids of all nodes, in creation order
        self.node_n = 0  # number of nodes created so far
    def treeGeneration(self, D):
        # Build the subtree for the instances in D and return the id of its root node.
        labels = [row[-1] for row in D]
        class_n = list(set(labels))
        # Stopping condition 1: all instances belong to the same class.
        if len(class_n) == 1:
            self.nodes.append([self.node_n, class_n[0], len(D), 0, -1, []])
            self.node_index.append(self.node_n)
            self.node_n += 1
            return self.node_n - 1
        # Stopping condition 2: all instances share the same attribute values,
        # so no split can separate them; label the leaf with the majority class.
        if all(row[:-1] == D[0][:-1] for row in D):
            majority = max(class_n, key=labels.count)
            self.nodes.append([self.node_n, majority, len(D), self.ent(D), -1, []])
            self.node_index.append(self.node_n)
            self.node_n += 1
            return self.node_n - 1
        # Compute the information gain of splitting on each attribute.
        a_n = len(D[0]) - 1
        gains = []
        for a in range(a_n):
            values = sorted(set(row[a] for row in D))
            if len(values) == 1:
                gains.append(0)  # a single value cannot split the node
            else:
                d = [[row for row in D if row[a] == v] for v in values]
                gains.append(self.gain(D, d))
        # Split on the attribute with the largest gain; children follow ascending value order.
        best_a = int(np.argmax(gains))
        a_s = sorted(set(row[best_a] for row in D))
        d = [[row for row in D if row[best_a] == v] for v in a_s]
        # Register the parent before recursing so that its id precedes its children's ids.
        father_node = self.node_n
        self.nodes.append([father_node, -1, len(D), self.ent(D), best_a, []])
        self.node_index.append(father_node)
        self.node_n += 1
        self.nodes[father_node][5] = [self.treeGeneration(subset) for subset in d]
        return father_node
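    def ent(self, D):
        # Entropy of the class labels (last column) of the instances in D.
        len_D = len(D)
        len_d = {}
        for i in D:
            len_d[i[-1]] = len_d.get(i[-1], 0) + 1
        p = [n / len_D for n in len_d.values()]
        return self.ent_(p)
    def ent_(self, p):
        # Shannon entropy in bits; p must contain only non-zero probabilities.
        return -sum([i * math.log(i, 2) for i in p])
    def gain(self, D, d):
        # Information gain of splitting D into the list of subsets d.
        D_len = len(D)
        return self.ent(D) - sum([(len(i) / D_len) * self.ent(i) for i in d])
    def fit(self, data):
        # Build the tree; each row of data is [attr_1, ..., attr_n, class].
        self.treeGeneration(data)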
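    def to_graph(self, filepath, attrs=None, classes=None):
        # Write the tree to filepath in DOT format.
        # attrs maps attribute index to name; classes maps class id to name.
        with open(filepath, "w") as dot_f:
            dot_f.write("digraph dtnodes{\n")
            # Node definitions; internal nodes keep the placeholder class -1.
            for node in self.nodes:
                cls = node[1]
                if classes is not None and cls != -1:
                    cls = classes[cls]
                # "\\n" writes the two characters \n, which DOT renders as a line break.
                dot_f.write("{}[label=\"class:{}\\nsamples:{}\\nentropy:{}\"];\n".format(node[0], cls, node[2], node[3]))
            # Arcs. num is the position of the child among its siblings; for the
            # Lenses data, where values are coded 1..k, it matches the attribute value.
            for node in self.nodes:
                num = 1
                for j in node[5]:
                    label = j if attrs is None else attrs[node[4]] + "=" + str(num)
                    dot_f.write("{}->{}[label=\"{}\"];\n".format(node[0], j, label))
                    num += 1
            dot_f.write("}")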
         
if __name__ == "__main__":
    data = []
    with open("./lenses.data", "r") as f:
        for line in f:
            fields = line.split()  # split on any run of whitespace
            if not fields:
                continue  # skip blank lines
            data.append([int(j) for j in fields[1:]])  # drop the instance id in column 0
    dt = DT()
    dt.fit(data)
    dt.to_graph("./demo.dot", attrs={0: "age", 1: "spectacle prescription", 2: "astigmatic", 3: "tear production rate"}, classes={1: "hard", 2: "soft", 3: "not"})