Exercise 4.3 from Zhou Zhihua's "Machine Learning"

Generate a decision tree for the data in Table 4.3.

The code is adapted from "Machine Learning in Action"; leaning on numpy and pandas simplifies it considerably. What makes this problem tricky is that the features in Table 4.3 mix discrete and continuous attributes. Without further ado, on to the code.
We start by defining a few helper functions. The natural approach is to think through the overall algorithm first, then define whatever functions it turns out to need.

import math
import pandas as pd
import numpy as np

from treePlotter import createPlot

def entropy(data):
    """Information entropy of the class label column (the last column)."""
    label_values = data[data.columns[-1]]
    # value_counts() returns the count of each unique label
    counts = label_values.value_counts()
    s = 0.0
    for c in label_values.unique():
        freq = float(counts[c]) / len(label_values)
        s -= freq * math.log(freq, 2)
    return s

def is_continuous(data, attr):
    """Check if attr is a continuous attribute."""
    return data[attr].dtype == 'float64'

def split_points(data, attr):
    """Returns Ta, Equation (4.7), p.84: midpoints of adjacent sorted values."""
    values = np.sort(data[attr].values)
    return [(x + y) / 2 for x, y in zip(values[:-1], values[1:])]

treePlotter is a module from "Machine Learning in Action" that draws the decision tree. The tree itself is represented as a dictionary: a key stands for either a node or a branch; as a node it is an attribute name, as a branch it is an attribute value. Each value is in turn either a dictionary (a subtree) or a string; a string marks a leaf, i.e., a class label. Here data is a pandas DataFrame, which is essentially a table and handles all the common table operations conveniently. Naming follows the book's conventions.
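For concreteness, a small hand-written tree in this representation could look like the sketch below (the attribute values and structure are invented purely for illustration):

# A hypothetical tree in the dictionary representation: 'texture' is a node
# (an attribute), 'clear' and 'fuzzy' are branches (attribute values), and
# the bare strings are leaves (class labels). Nested dicts are subtrees.
toy_tree = {'texture': {'clear': {'touch': {'hard slip': 'yes',
                                            'soft sticky': 'no'}},
                        'fuzzy': 'no'}}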

Next, let's see how to compute information gain:

def discrete_gain(data, attr):
    """Information gain of a discrete attribute, paired with None as a marker."""
    V = data[attr].unique()
    s = 0
    for v in V:
        data_v = data[data[attr] == v]
        s += float(len(data_v)) / len(data) * entropy(data_v)
    return (entropy(data) - s, None)

def continuous_gain(data, attr, points):
    """Equation (4.8), p.84: returns the max gain along with its splitting point."""
    entD = entropy(data)
    # gains is a list of pairs of the form (gain, t)
    gains = []
    for t in points:
        d_plus = data[data[attr] > t]
        d_minus = data[data[attr] <= t]
        gain = entD - (float(len(d_plus)) / len(data) * entropy(d_plus)
                       + float(len(d_minus)) / len(data) * entropy(d_minus))
        gains.append((gain, t))
    return max(gains)

The discrete case is straightforward; the None in the returned pair is a flag for the functions downstream: whoever sees None knows the attribute is discrete. For a continuous attribute we compute the gain at every candidate split point t, store each (gain, t) pair in a list, and take the maximum.
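Both computations lean on entropy, so here is a quick sanity check. The dataset below has 8 good melons and 9 bad ones out of 17, so with it loaded as data the entropy should come out to -(8/17)*log2(8/17) - (9/17)*log2(9/17) ≈ 0.998, the same Ent(D) the book computes for this class distribution:

>>> entropy(data)   # 8 'yes' labels vs 9 'no' labels out of 17
0.9975...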
Then comes the dispatching gain function that handles both cases:

def gain(data, attr):
    if is_continuous(data, attr):
        points = split_points(data, attr)
        return continuous_gain(data, attr, points)
    else:
        return discrete_gain(data, attr)
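A quick check against the book's worked example; the numbers below (gain 0.262 at t = 0.381 for density, gain 0.349 at t = 0.126 for sugar content) are the values the book reports for this dataset, shown truncated, so treat them as expected rather than guaranteed output:

>>> gain(data, 'density')        # continuous: (best gain, best split point t)
(0.262..., 0.381...)
>>> gain(data, 'sugar content')  # continuous as well
(0.349..., 0.126...)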

We also need a function for the mode (the majority class):

def majority(label_values):
    """The most frequent label; value_counts() sorts counts in descending order."""
    counts = label_values.value_counts()
    return counts.index[0]
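For example:

>>> majority(pd.Series(['yes', 'no', 'yes']))
'yes'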

Our id3 finally makes its entrance:

def id3(data):
    attrs = data.columns[:-1]
    # attrWithGain is of the form [(attr, (gain, t))], t is None if attr is discrete
    attrWithGain = [(a, gain(data, a)) for a in attrs]
    attrWithGain.sort(key=lambda tup: tup[1][0], reverse=True)
    return attrWithGain[0]

It computes the information gain of every attribute and returns the one with the largest gain, along with the two extra values, in the form (attr, (gain, t)).
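On the full dataset it is texture that wins at the root, consistent with Figure 4.8; if the book's worked numbers hold (gain ≈ 0.381 for texture), the call looks roughly like:

>>> id3(data)
('texture', (0.381..., None))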

Finally, we build the tree:

def createTree(data, split_function):
    label = data.columns[-1]
    label_values = data[label]
    # Stop when all examples have the same class
    if len(label_values.unique()) == 1:
        return label_values.values[0]
    # When no features remain, or only one feature with identical values, return the majority class
    if data.shape[1] == 1 or (data.shape[1] == 2 and len(data.iloc[:, 0].unique()) == 1):
        return majority(label_values)
    bestAttr, (g, t) = split_function(data)
    # If bestAttr is discrete
    if t is None:
        # A key is an attribute (node); its value is a dictionary mapping attribute values to subtrees
        myTree = {bestAttr: {}}
        values = data[bestAttr].unique()
        for v in values:
            data_v = data[data[bestAttr] == v]
            attrsAndLabel = data.columns.tolist()
            attrsAndLabel.remove(bestAttr)
            data_v = data_v[attrsAndLabel]
            myTree[bestAttr][v] = createTree(data_v, split_function)
        return myTree
    # If bestAttr is continuous
    else:
        t = round(t, 3)
        node = bestAttr + '<=' + str(t)
        myTree = {node: {}}
        for v in ['yes', 'no']:
            data_v = data[data[bestAttr] <= t] if v == 'yes' else data[data[bestAttr] > t]
            myTree[node][v] = createTree(data_v, split_function)
        return myTree

I'll let the code speak for itself, except for one point worth noting: when recursing on a discrete attribute, the current attribute is removed from the candidates (attrsAndLabel.remove(bestAttr)), since it must not appear again in deeper branches. A continuous attribute, by contrast, is kept and may be split again. That is easy to understand: since continuous attributes are handled with binary splits, it may take several of them to sort things out.
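The dictionary representation also makes prediction straightforward. Below is a minimal classify sketch of my own (classify and sample are not part of the original code) that walks the tree for a single example, parsing the 'attr<=t' node labels back into comparisons:

def classify(tree, sample):
    """Walk the dict tree for one example, given as a dict (or Series) of attribute values."""
    if not isinstance(tree, dict):       # a leaf is just the class label string
        return tree
    node = list(tree.keys())[0]
    if '<=' in node:                     # continuous node, e.g. 'density<=0.381'
        attr, t = node.split('<=')
        branch = 'yes' if sample[attr] <= float(t) else 'no'
    else:                                # discrete node: follow the branch for the value
        branch = sample[node]
    return classify(tree[node][branch], sample)

On consistent training data a fully grown tree reproduces the training labels, so classify(tree, data.iloc[0]) ought to return 'yes' for example 1.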

Let's test it:

if __name__ == "__main__":
    f = pd.read_csv(filepath_or_buffer='dataset/watermelon3.0en.csv', sep=',')
    # Drop the id column; keep the attributes and the label
    data = f[f.columns[1:]]

    tree = createTree(data, id3)
    print(tree)
    createPlot(tree)

I translated the original table into English, because Chinese characters don't show up when the dictionary is printed, and the plotting fails outright.

id,color,root,knock,texture,umbilical,touch,density,sugar content,good melon
1,green,curled up,cloudy,clear,concave,hard slip,0.697,0.46,yes
2,black,curled up,dull,clear,concave,hard slip,0.774,0.376,yes
3,black,curled up,cloudy,clear,concave,hard slip,0.634,0.264,yes
4,green,curled up,dull,clear,concave,hard slip,0.608,0.318,yes
5,pale,curled up,cloudy,clear,concave,hard slip,0.556,0.215,yes
6,green,slightly curled,cloudy,clear,slightly concave,soft sticky,0.403,0.237,yes
7,black,slightly curled,cloudy,slightly paste,slightly concave,soft sticky,0.481,0.149,yes
8,black,slightly curled,cloudy,clear,slightly concave,hard slip,0.437,0.211,yes
9,black,slightly curled,dull,slightly paste,slightly concave,hard slip,0.666,0.091,no
10,green,stiff,crisp,clear,flat,soft sticky,0.243,0.267,no
11,pale,stiff,crisp,fuzzy,flat,hard slip,0.245,0.057,no
12,pale,curled up,cloudy,fuzzy,flat,soft sticky,0.343,0.099,no
13,green,slightly curled,cloudy,slightly paste,concave,hard slip,0.639,0.161,no
14,pale,slightly curled,dull,slightly paste,concave,hard slip,0.657,0.198,no
15,black,slightly curled,cloudy,clear,slightly concave,soft sticky,0.36,0.37,no
16,pale,curled up,cloudy,fuzzy,flat,hard slip,0.593,0.042,no
17,green,curled up,dull,slightly paste,slightly concave,hard slip,0.719,0.103,no

I won't include treePlotter here, so you'll have to make do with reading the dictionary.
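If you'd rather not squint at nested braces, a few lines suffice to print the dictionary as an indented tree (print_tree is my own stand-in, not part of the book's code):

def print_tree(tree, indent=''):
    """Print the dict tree as indented text: a node in brackets, then each branch."""
    if not isinstance(tree, dict):
        print(indent + '-> ' + str(tree))
        return
    node = list(tree.keys())[0]
    print(indent + '[' + node + ']')
    for value, subtree in tree[node].items():
        print(indent + '  ' + str(value) + ':')
        print_tree(subtree, indent + '    ')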
The plotted tree matches Figure 4.8 in the book:

(Figure: the decision tree generated from watermelon dataset 3.0 using information gain)

Copy the code into an editor in order, save it, and it will run; remember to comment out the treePlotter import and the createPlot call if you don't have that module.
