Zhou Zhihua's "Machine Learning", Exercise 4.3

Generate a decision tree for the data in Table 4.3.

The code is adapted from the code in "Machine Learning in Action"; once numpy and pandas are brought in, it becomes noticeably simpler. The data in Table 4.3 contains both discrete and continuous attributes, and that is exactly where the problem gets complicated. Without further ado, on to the code.
First we define a few helper functions. The natural workflow is to sketch the overall algorithm first, then define whatever functions it turns out to need.

import math
import pandas as pd
import numpy as np

from treePlotter import createPlot
def entropy(data):
    """Information entropy Ent(D), Eq.(4.1); the class label is assumed to be the last column"""
    label_values = data[data.columns[-1]]
    #value_counts returns the count of each unique label
    counts = label_values.value_counts()
    s = 0
    for c in label_values.unique():
        freq = float(counts[c])/len(label_values) 
        s -= freq*math.log(freq,2)
    return s

def is_continuous(data,attr):
    """Check if attr is a continuous attribute"""
    return data[attr].dtype == 'float64'

def split_points(data,attr):
    """Returns Ta,Equation(4.7),p.84"""
    values = np.sort(data[attr].values)
    return [(x+y)/2 for x,y in zip(values[:-1],values[1:])] 
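As a quick sanity check of these helpers (entropy implements Ent(D) = -Σ_k p_k log2(p_k), Eq.(4.1) in the book), a 50/50 class split should come out as exactly 1 bit. The frame below is made up purely for illustration:

toy = pd.DataFrame({'density': [0.25, 0.5, 0.75, 1.0],
                    'good melon': ['no', 'no', 'yes', 'yes']})
print(entropy(toy))                   # 1.0: two classes, half the samples each
print(is_continuous(toy, 'density'))  # True: the column dtype is float64
print(split_points(toy, 'density'))   # [0.375, 0.625, 0.875]: midpoints of adjacent values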

treePlotter is a module from "Machine Learning in Action" that draws the decision tree. The tree here is represented as a dict: a key stands for either a node or a branch of the tree; as a node it is an attribute, as a branch it is an attribute value. Each value is in turn either a dict or a string; a string marks a leaf, i.e. a class label. Here data is a pandas DataFrame, which behaves like a table and conveniently handles all the usual table operations. The naming follows the book.
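For example, a tree of the following shape (hypothetical, just to illustrate the representation) reads: split on texture; clear means a good melon, fuzzy means not, and slightly paste is decided by a further split on touch:

#Hypothetical dict tree: keys alternate between attributes (nodes) and
#attribute values (branches); plain strings are leaf labels
example = {'texture': {'clear': 'yes',
                       'fuzzy': 'no',
                       'slightly paste': {'touch': {'hard slip': 'no',
                                                    'soft sticky': 'yes'}}}}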

Next, let's look at how to compute the information gain:

def discrete_gain(data,attr):
    """Information gain of a discrete attribute, Eq.(4.2); the None marks attr as discrete"""
    V = data[attr].unique()
    s = 0
    for v in V:
        data_v = data[data[attr]== v]
        s += float(len(data_v))/len(data)*entropy(data_v)
    return (entropy(data) - s,None)

def continuous_gain(data,attr,points):
    """Equation(4.8),p.84,returns the max gain along with its splitting point"""
    entD = entropy(data)
    #gains is a list of pairs of the form (gain,t)
    gains = []
    for t in points:
        d_plus = data[data[attr] > t]
        d_minus = data[data[attr] <= t]
        gain = entD - (float(len(d_plus))/len(data)*entropy(d_plus)
                       + float(len(d_minus))/len(data)*entropy(d_minus))
        gains.append((gain,t))
    return max(gains)

The information gain for a discrete attribute is self-explanatory; the None in the returned pair is there for later functions to key on: seeing None means the attribute is discrete. For a continuous attribute, the gain is computed for every candidate split point t, each (gain, t) pair is stored in a list, and the largest one is taken at the end.
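To see the mechanics on a made-up example: every candidate t yields one (gain, t) pair, and max returns the best one:

toy = pd.DataFrame({'density': [0.25, 0.5, 0.75, 1.0],
                    'good melon': ['no', 'no', 'yes', 'yes']})
pts = split_points(toy, 'density')          # [0.375, 0.625, 0.875]
print(continuous_gain(toy, 'density', pts))
#(1.0, 0.625): the split at 0.625 separates the labels perfectly,
#so the gain equals Ent(D) = 1 bit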
Then comes the gain function that dispatches between the two cases:

def gain(data,attr):
    if is_continuous(data,attr):
        points = split_points(data,attr)
        return continuous_gain(data,attr,points)
    else:
        return discrete_gain(data,attr)
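A quick hypothetical check that gain dispatches on dtype: an object column goes through discrete_gain, a float64 column through continuous_gain:

toy = pd.DataFrame({'color': ['green', 'green', 'black', 'black'],
                    'density': [0.25, 0.5, 0.75, 1.0],
                    'good melon': ['no', 'no', 'yes', 'yes']})
print(gain(toy, 'color'))    # (1.0, None): discrete, and it separates the labels perfectly
print(gain(toy, 'density'))  # (1.0, 0.625): continuous, best threshold t = 0.625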

We also need a mode (majority) function:

def majority(label_values):
    #value_counts sorts in descending order, so index[0] is the most frequent label
    counts = label_values.value_counts()
    return counts.index[0]

Our id3 finally makes its entrance:

def id3(data):
    attrs = data.columns[:-1]
    #attrWithGain is of the form [(attr,(gain,t))], t is None if attr is discrete
    attrWithGain = [(a,gain(data,a)) for a in attrs] 
    attrWithGain.sort(key = lambda tup:tup[1][0],reverse = True)
    return attrWithGain[0]

It computes the information gain of every attribute and returns the attribute with the largest gain, together with two extra values, in the form (attr, (gain, t)).
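Hypothetical usage on the full watermelon 3.0 table; according to the book, texture has the largest gain there (0.381), which is why it becomes the root of Figure 4.8:

best_attr, (best_gain, t) = id3(data)
#best_attr == 'texture'; t is None because texture is discrete.
#Had a continuous attribute won, t would hold the chosen threshold instead.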

Finally, build the tree:

def createTree(data,split_function):
    label = data.columns[-1]
    label_values = data[label]
    #Stop when all samples belong to the same class
    if len(label_values.unique()) == 1:
        return label_values.values[0]
    #When no attributes remain, or only one attribute whose values are all the same, return the majority class
    if data.shape[1] == 1 or (data.shape[1] == 2 and len(data.iloc[:,0].unique()) == 1):
        return majority(label_values)
    bestAttr,(g,t) = split_function(data)
    #If bestAttr is discrete
    if t is None:
        #A key is an attribute (a node); its value is a dict mapping attribute values (branches) to subtrees or leaf labels
        myTree = {bestAttr:{}}
        values = data[bestAttr].unique() 
        for v in values:
            data_v = data[data[bestAttr]== v]
            attrsAndLabel = data.columns.tolist()
            attrsAndLabel.remove(bestAttr)
            data_v = data_v[attrsAndLabel]
            myTree[bestAttr][v] = createTree(data_v,split_function)
        return myTree
    #If bestAttr is continuous
    else:
        t = round(t,3)
        node = bestAttr+'<='+str(t)
        myTree = {node:{}}
        values = ['yes','no']
        for v in values:
            data_v = data[data[bestAttr] <= t] if v == 'yes' else data[data[bestAttr] > t]
            myTree[node][v] = createTree(data_v,split_function)
        return myTree

I won't go through this line by line; read it for yourself. One thing worth pointing out: when the split attribute is discrete, the next recursion removes it (attrsAndLabel.remove(bestAttr)), because that attribute must not appear again in the branches below it. When it is continuous, however, it is kept and may appear again. This is easy to understand: continuous attributes are handled by binary splits, and it may take several successive splits to sort the cases out.
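Concretely, a subtree shaped like the following hypothetical fragment is perfectly legal: density gets tested twice along one path, with a different threshold each time (the values are made up, not actual output):

#A continuous attribute may reappear below its own branch;
#createTree only forbids reuse for discrete attributes
fragment = {'density<=0.381': {'yes': 'no',
                               'no': {'density<=0.56': {'yes': 'no',
                                                        'no': 'yes'}}}}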

Test it:

if __name__ == "__main__":
    f = pd.read_csv(filepath_or_buffer = 'dataset/watermelon3.0en.csv', sep = ',')
    data = f[f.columns[1:]]

    tree = createTree(data,id3)
    print(tree)
    createPlot(tree)

I translated the original table into English, because a printed dict does not display Chinese characters, and plotting with them fails outright.

id,color,root,knock,texture,umbilical,touch,density,sugar content,good melon
1,green,curled up,cloudy,clear,concave,hard slip,0.697,0.46,yes
2,black,curled up,dull,clear,concave,hard slip,0.774,0.376,yes
3,black,curled up,cloudy,clear,concave,hard slip,0.634,0.264,yes
4,green,curled up,dull,clear,concave,hard slip,0.608,0.318,yes
5,pale,curled up,cloudy,clear,concave,hard slip,0.556,0.215,yes
6,green,slightly curled,cloudy,clear,slightly concave,soft sticky,0.403,0.237,yes
7,black,slightly curled,cloudy,slightly paste,slightly concave,soft sticky,0.481,0.149,yes
8,black,slightly curled,cloudy,clear,slightly concave,hard slip,0.437,0.211,yes
9,black,slightly curled,dull,slightly paste,slightly concave,hard slip,0.666,0.091,no
10,green,stiff,crisp,clear,flat,soft sticky,0.243,0.267,no
11,pale,stiff,crisp,fuzzy,flat,hard slip,0.245,0.057,no
12,pale,curled up,cloudy,fuzzy,flat,soft sticky,0.343,0.099,no
13,green,slightly curled,cloudy,slightly paste,concave,hard slip,0.639,0.161,no
14,pale,slightly curled,dull,slightly paste,concave,hard slip,0.657,0.198,no
15,black,slightly curled,cloudy,clear,slightly concave,soft sticky,0.36,0.37,no
16,pale,curled up,cloudy,fuzzy,flat,hard slip,0.593,0.042,no
17,green,curled up,dull,slightly paste,slightly concave,hard slip,0.719,0.103,no

I won't post treePlotter here; please make do with reading the dict output for now.
The plotted tree is the same as Figure 4.8 in the book:

(Figure: the decision tree generated from watermelon dataset 3.0 by information gain)

Copy the code blocks into an editor in order and save the file, and it will run; remember to comment out the treePlotter import and the createPlot call if you don't have that module.
