Hand-Written Decision Tree with Visualization

Decision Tree

Visualization


Description

The data is the Lenses Data Set from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/Lenses).

It contains:
24 instances
3 classes:
1 : the patient should be fitted with hard contact lenses,
2 : the patient should be fitted with soft contact lenses,
3 : the patient should not be fitted with contact lenses.

4 attributes:
1: age of the patient: (1) young, (2) pre-presbyopic, (3) presbyopic
2: spectacle prescription: (1) myope, (2) hypermetrope
3: astigmatic: (1) no, (2) yes
4: tear production rate: (1) reduced, (2) normal

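Source data format (excerpt of lenses.data; the first column is the instance number, the next four are the attribute values, and the last is the class):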

1 1 1 1 1 3
2 1 1 1 2 2
3 1 1 2 1 3

20 3 1 2 2 1
21 3 2 1 1 3
22 3 2 1 2 2
23 3 2 2 1 3
24 3 2 2 2 3
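Reading these rows only requires dropping the instance number and converting the rest to integers. A minimal parsing sketch, assuming whitespace-separated integer columns as in the excerpt above; the helper name `parse_lenses` is hypothetical. Splitting on arbitrary whitespace avoids depending on how many spaces separate the columns (the appendix code's `replace("  ", " ")` suggests the raw file uses more than one).

```python
def parse_lenses(lines):
    """Turn raw lenses.data lines into [age, prescription, astigmatic, tear_rate, class] rows.

    Hypothetical helper: assumes whitespace-separated integers with the
    instance number in the first column and the class in the last.
    """
    rows = []
    for line in lines:
        fields = line.split()  # any run of spaces/tabs separates columns
        if not fields:         # skip blank lines, e.g. a trailing newline at EOF
            continue
        rows.append([int(x) for x in fields[1:]])  # drop the instance number
    return rows


sample = ["1  1  1  1  1  3\n", "2  1  1  1  2  2\n", "24 3  2  2  2  3\n", "\n"]
print(parse_lenses(sample))  # [[1, 1, 1, 1, 3], [1, 1, 1, 2, 2], [3, 2, 2, 2, 3]]
```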

Method

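The decision tree is generated top-down by recursion.
Stopping conditions (the node becomes a leaf):

1. All instances in the node belong to the same class
2. All instances in the node have identical attribute values (the leaf is then labelled with the majority class)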
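The two stopping conditions can be checked directly on a node's instances. A minimal sketch, assuming each row is [attr_1, ..., attr_k, class] as in the parsed data; the function name `stop_check` is hypothetical, and choosing the majority class for condition 2 follows what the appendix code intends.

```python
from collections import Counter


def stop_check(D):
    """Return the leaf label if recursion should stop at D, else None.

    Hypothetical helper. Each row of D is [attr_1, ..., attr_k, class].
    Condition 1: every instance has the same class -> that class.
    Condition 2: every instance has the same attribute values -> majority class.
    """
    labels = [row[-1] for row in D]
    if len(set(labels)) == 1:
        return labels[0]
    if all(row[:-1] == D[0][:-1] for row in D):
        return Counter(labels).most_common(1)[0][0]
    return None  # neither condition holds: keep splitting


print(stop_check([[1, 1, 3], [2, 1, 3]]))             # 3 (condition 1)
print(stop_check([[1, 2, 1], [1, 2, 3], [1, 2, 3]]))  # 3 (condition 2, majority)
print(stop_check([[1, 1, 1], [2, 1, 3]]))             # None
```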

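Main steps:
For each attribute, compute the information gain of splitting the node on that attribute
Choose the attribute with the largest information gain and split the node on it
Apply the same procedure recursively to each child node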
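The split criterion is the standard ID3-style information gain. With p_k the fraction of instances of class k in node D, the entropy is Ent(D) = -Σ_k p_k·log2(p_k), and splitting D on attribute a into subsets D^v (one per value v) gives Gain(D, a) = Ent(D) - Σ_v (|D^v| / |D|)·Ent(D^v). A minimal sketch on a hypothetical toy dataset; the function names are illustrative, not taken from the appendix code.

```python
import math
from collections import Counter


def entropy(D):
    """Ent(D) = -sum_k p_k * log2(p_k) over the class labels (last column)."""
    counts = Counter(row[-1] for row in D)
    n = len(D)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())


def info_gain(D, a):
    """Gain(D, a) = Ent(D) - sum_v |D_v|/|D| * Ent(D_v) for attribute index a."""
    subsets = {}
    for row in D:
        subsets.setdefault(row[a], []).append(row)  # group rows by value of a
    n = len(D)
    return entropy(D) - sum(len(Dv) / n * entropy(Dv) for Dv in subsets.values())


def best_attribute(D):
    """Index of the attribute with the largest information gain."""
    n_attrs = len(D[0]) - 1
    return max(range(n_attrs), key=lambda a: info_gain(D, a))


# toy data: [attr0, attr1, class]; attr0 separates the classes perfectly
D = [[1, 1, "yes"], [1, 2, "yes"], [2, 1, "no"], [2, 2, "no"]]
print(entropy(D))         # 1.0
print(info_gain(D, 0))    # 1.0
print(info_gain(D, 1))    # 0.0
print(best_attribute(D))  # 0
```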

Implementation

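The code is written in Python, in a class named DT.
Each node is stored as a list: [id, class, samples, ent, attr, children], where id is the node number, class is the predicted class (-1 for internal nodes), samples is the number of instances in the node, ent is their entropy, and attr is the index of the splitting attribute (-1 for leaves).
children is the list of child node ids; an empty list marks a leaf node.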
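To make the node layout concrete, here is a tiny hand-built tree in the [id, class, samples, ent, attr, children] format, assuming (as in the appendix code) that class is -1 for internal nodes and attr is -1 for leaves. The helper `leaves` and the node values are hypothetical, for illustration only; it walks the list from the root and collects the leaves, which are exactly the nodes whose children list is empty.

```python
def leaves(nodes, root=0):
    """Return the ids of all leaves reachable from root, depth first.

    nodes[i] has the layout [id, class, samples, ent, attr, children];
    a node is a leaf when its children list is empty.
    """
    stack, found = [root], []
    while stack:
        node = nodes[stack.pop()]
        if not node[5]:
            found.append(node[0])
        else:
            # push children in reverse so they are visited left to right
            stack.extend(reversed(node[5]))
    return found


# hypothetical tree: the root splits on attribute 3 into two leaves
nodes = [
    [0, -1, 6, 0.918, 3, [1, 2]],  # internal node: class -1, split on attr 3
    [1, 3, 4, 0.0, -1, []],        # leaf: class 3, 4 instances
    [2, 2, 2, 0.0, -1, []],        # leaf: class 2, 2 instances
]
print(leaves(nodes))  # [1, 2]
```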

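Visualization
Visualization uses the DOT language (Graphviz): the tree is traversed, every node and edge is written to a .dot file in DOT syntax, and the file is finally exported to an image.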
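A minimal sketch of the DOT output, assuming the same node list layout; the function name `to_dot`, the tree, and the edge labels are hypothetical. The label line breaks are written as the two characters `\n`, which Graphviz interprets as a line break inside a quoted label.

```python
def to_dot(nodes, edge_labels):
    """Build a Graphviz DOT string from [id, class, samples, ent, attr, children] nodes.

    edge_labels maps (parent_id, child_id) -> text shown on that edge.
    """
    lines = ["digraph dtnodes {"]
    for node_id, cls, samples, ent, _attr, _children in nodes:
        # "\\n" emits DOT's own \n escape, i.e. a line break inside the label
        lines.append('{}[label="class:{}\\nsamples:{}\\nentropy:{:.3f}"];'.format(
            node_id, cls, samples, ent))
    for node_id, _cls, _samples, _ent, _attr, children in nodes:
        for child in children:
            lines.append('{}->{}[label="{}"];'.format(
                node_id, child, edge_labels[(node_id, child)]))
    lines.append("}")
    return "\n".join(lines)


nodes = [
    [0, -1, 6, 0.918, 3, [1, 2]],
    [1, 3, 4, 0.0, -1, []],
    [2, 2, 2, 0.0, -1, []],
]
dot = to_dot(nodes, {(0, 1): "attr3=1", (0, 2): "attr3=2"})
print(dot)
```

Saving the string to demo.dot and running Graphviz as `dot -Tpng demo.dot -o demo.png` renders the image.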

Appendix: source code

import math
import numpy as np

class DT():
    def __init__(self):
        # each node: [id, class, samples, ent, attr, children]
        self.nodes = []
        self.node_index = []
        self.node_n = 0
        # node id -> attribute values of its children, in the same order as children
        self.split_values = {}

    def new_node(self, cls, D, ent, attr):
        # append a node and return its id
        self.nodes.append([self.node_n, cls, len(D), ent, attr, []])
        self.node_index.append(self.node_n)
        self.node_n += 1
        return self.node_n - 1

    def treeGeneration(self, D):
        labels = [row[-1] for row in D]
        class_n = list(set(labels))
        # stop 1: all instances belong to the same class -> leaf
        if len(class_n) == 1:
            return self.new_node(class_n[0], D, 0, -1)
        # stop 2: all instances have identical attributes -> leaf with the majority class
        if all(row[:-1] == D[0][:-1] for row in D):
            majority = max(class_n, key=labels.count)
            return self.new_node(majority, D, self.ent(D), -1)
        # information gain of every attribute; constant attributes get -1 so they
        # are never chosen (choosing one would recurse forever on the same D)
        gains = []
        for a in range(len(D[0]) - 1):
            values = sorted(set(row[a] for row in D))
            if len(values) == 1:
                gains.append(-1)
            else:
                gains.append(self.gain(D, self.split(D, a, values)))
        best_a = int(np.argmax(gains))
        values = sorted(set(row[best_a] for row in D))
        father_node = self.new_node(-1, D, self.ent(D), best_a)
        self.split_values[father_node] = values
        # recurse on each subset and record the child ids
        self.nodes[father_node][5] = [self.treeGeneration(d) for d in self.split(D, best_a, values)]
        return father_node

    def split(self, D, a, values):
        # partition D by the value of attribute a, one subset per value
        d = [[] for _ in values]
        for row in D:
            d[values.index(row[a])].append(row)
        return d

    def ent(self, D):
        # entropy of the class distribution of D
        labels = [row[-1] for row in D]
        p = [labels.count(c) / len(D) for c in set(labels)]
        return self.ent_(p)

    def ent_(self, p):
        return -sum([i * math.log(i, 2) for i in p])

    def gain(self, D, d):
        # information gain of splitting D into the subsets d
        D_len = len(D)
        return self.ent(D) - sum([(len(i) / D_len) * self.ent(i) for i in d])

    def fit(self, data):
        self.treeGeneration(data)

    def to_graph(self, filepath, attrs=None, classes=None):
        with open(filepath, "w") as dot_f:
            dot_f.write("digraph dtnodes{\n")
            # nodes; "\\n" writes DOT's own \n line-break escape into the label
            for node in self.nodes:
                cls = node[1] if classes is None or node[1] == -1 else classes[node[1]]
                dot_f.write("{}[label=\"class:{}\\nsamples:{}\\nentropy:{:.3f}\"];\n".format(node[0], cls, node[2], node[3]))
            # edges, labelled with the attribute value that leads to each child
            for node in self.nodes:
                for child, value in zip(node[5], self.split_values.get(node[0], [])):
                    label = str(value) if attrs is None else attrs[node[4]] + "=" + str(value)
                    dot_f.write("{}->{}[label=\"{}\"];\n".format(node[0], child, label))
            dot_f.write("}")

data = []
with open("./lenses.data", "r") as f:
    for line in f:
        fields = line.split()
        if fields:
            # drop the instance number, keep the 4 attributes and the class
            data.append([int(j) for j in fields[1:]])
dt = DT()
dt.fit(data)
dt.to_graph("./demo.dot", attrs={0: "age", 1: "spectacle prescription", 2: "astigmatic", 3: "tear production rate"}, classes={1: "hard", 2: "soft", 3: "not"})