Python手寫決策樹算法

數據集準備

web站點的用戶在線瀏覽行爲及最終購買行爲,每個用戶的在線瀏覽行爲信息包括:每個用戶的來源網站、用戶的ip位置、是否閱讀FAQ、瀏覽網頁數目。目標分類爲用戶類型:遊客、基本用戶、高級用戶

算法 支持模型 數結構 特徵選擇 連續值處理 缺失值處理 剪枝
ID3 分類 多叉樹 信息增益 不支持 不支持 不支持
C4.5 分類 多叉樹 信息增益比 支持 支持 支持
CART 分類、迴歸 二叉樹 基尼指數、均方差 支持 支持 支持
my_data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]

數據集第一列用戶網站來源,有slashdot、google、digg、kiwitobes、(direct)五種;第二列是用戶IP的位置,有USA、UK、France和New Zealand;第三列是否閱讀FAQ,yes或no;第四列是網頁瀏覽數目,有12、18、19、21、23、24;最後一列是標籤,代表瀏覽用戶類型,有遊客None、基本用戶Basic、高級用戶Premium。

決策樹前的準備工作

採用樹結構:二叉樹
選擇屬性準則:信息增益
目的:分類

  • 算法思想:

生成決策樹函數(訓練樣本)
if 訓練樣本爲空 : 返回一個空結點
end if 訓練樣本信息增益爲0或屬性所有取值相等:返回葉子結點
else:
——for 某一個屬性 in 訓練樣本:#遍歷每一個屬性
———— for 某屬性的某個值 in 某一個屬性:
——————根據某屬性的某個值,將訓練樣本劃分爲兩個,計算信息增益
————end for
——end for
——選擇最大的信息增益,將這個屬性按這個值進行左右分支,並回調這個函數兩次。

  • 代碼模塊功能

uniquecounts(rows):對訓練樣本的每一個可能的標籤結果進行統計,返回dict類型。
entropy(rows):配合uniquecounts方法,進行信息熵的計算,返回float值。
divideset(rows,column,value):給定訓練樣本、第幾列、某列中的某個值,效果是將訓練樣本以某列的某個值一分爲二,返回倆個訓練樣本。
class decisionnode():初始化結點功能,用於構建樹。
buildtree(rows,scoref = entropy):核心函數,參照算法思想的邏輯構寫,遞歸生成結點。

完整代碼

from math import log
def divideset(rows,column,value):
    #對於給定的集合,指定列數和這列某一個值
    if isinstance(value,int) or isinstance(value,float):
        split_function = lambda row:row[column] >= value
    else:
        split_function = lambda row:row[column]==value
    set1 = [row for row in rows if split_function(row)]
    set2 = [row for row in rows if not split_function(row)]
    return(set1,set2)
#定義節點的屬性
class decisionnode:
    def __init__(self,col = -1,value = None, results = None, tb = None,fb = None):
        self.col = col   # col是待檢驗的判斷條件所對應的列索引值
        self.value = value # value對應於爲了使結果爲True,當前列必須匹配的值
        self.results = results #保存的是針對當前分支的結果,它是一個字典
        self.tb = tb ## desision node,右子樹,True對應
        self.fb = fb ## desision node,左子樹,False對應
# 對y的各種可能的取值出現的個數進行計數.。其他函數利用該函數來計算數據集和的混雜程度例如{'Basic': 6, 'None': 7, 'Premium': 3}
def uniquecounts(rows):
    results = {}
    for row in rows:
        #計數結果在最後一列
        r = row[len(row)-1]
        if r not in results:
            results[r] = 0
        results[r]+=1
    return results # 返回一個字典
def entropy(rows):
    """
    計算當前集合的信息熵H(D)
    """  
    log2 = lambda x:log(x)/log(2)
#     log2 = lambda x:log(2,x)
    results = uniquecounts(rows)
    #開始計算熵的值
    ent = 0.0
    for r in results.keys():
        p = float(results[r])/len(rows)
        ent -=  p*log2(p)
    return ent
# 以遞歸方式構造樹

def buildtree(rows,scoref = entropy):
    if len(rows)==0 : return decisionnode()
    current_score = scoref(rows)
    
    # 定義一些變量以記錄最佳拆分條件
    best_gain = 0.0
    best_criteria = None
    best_sets = None
    
    column_count = len(rows[0]) - 1
    for col in range(0,column_count): #遍歷所有屬性列
        #在當前列中生成一個由不同值構成的序列
        column_values = set({})  #用集合存放當前列有哪些不同值
        for row in rows:
            column_values.add(row[col]) 
#         print(column_values)
        #根據這一列中的每個值,嘗試對數據集進行拆分
        for value in column_values:  #某個屬性可能的所有值,比如第一列:(direct)、'digg'、 'google'、 'kiwitobes'、 'slashdot'
            (set1,set2) = divideset(rows,col,value)
            
            # 信息增益
            p = float(len(set1))/len(rows)
            gain = current_score - p*scoref(set1) - (1-p)*scoref(set2)
            if gain>best_gain and len(set1)>0 and len(set2)>0:
                best_gain = gain
                best_criteria = (col,value)  #結點分裂的屬性
                best_sets = (set1,set2)
                
    #創建子分支
    if best_gain>0:
        trueBranch = buildtree(best_sets[0])  #遞歸調用
        falseBranch = buildtree(best_sets[1])
        return decisionnode(col = best_criteria[0],value = best_criteria[1],
                            tb = trueBranch,fb = falseBranch)
    else:
        print(uniquecounts(rows))
        return decisionnode(results = uniquecounts(rows)) #{'Basic': 6, 'None': 7, 'Premium': 3}
tree = buildtree(my_data)

畫出生成的二叉樹

#繪製決策樹
from PIL import Image, ImageDraw

# 獲取樹的顯示寬度,即有多少個葉子節點
def getwidth(tree):
    if tree.tb==None and tree.fb==None: return 1
    return getwidth(tree.tb)+getwidth(tree.fb)

# 獲取樹的顯示深度(高度)
def getdepth(tree):
    if tree.tb==None and tree.fb==None: return 0
    return max(getdepth(tree.tb),getdepth(tree.fb))+1

# 繪製樹形圖
def drawtree(tree,jpeg='tree.jpg'):
    w=getwidth(tree)*100
    h=getdepth(tree)*100+120

    img=Image.new('RGB',(w,h),(255,255,255))
    draw=ImageDraw.Draw(img)

    drawnode(draw,tree,w/2,20)  #根節點座標
    img.save(jpeg,'JPEG')

# 迭代畫樹的節點
def drawnode(draw,tree,x,y):
    if tree.results==None:
        # 得到每個分支的寬度
        w1=getwidth(tree.fb)*100
        w2=getwidth(tree.tb)*100

        # 確定此節點所要佔據的總空間
        left=x-(w1+w2)/2
        right=x+(w1+w2)/2

        # 繪製判斷條件字符串
        draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))

        # 繪製到分支的連線
        draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
        draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))

        # 繪製分支的節點
        drawnode(draw,tree.fb,left+w1/2,y+100)
        drawnode(draw,tree.tb,right-w2/2,y+100)
    else:
        txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
        draw.text((x-20,y),txt,(0,0,0))
drawtree(tree,jpeg='tree.jpg')


預測

def classify(observation,tree):
    if tree.results!=None:
        return tree.results
    else:
        v=observation[tree.col]
        branch=None
        if isinstance(v,int) or isinstance(v,float):
            if v>=tree.value: branch=tree.tb #走右子樹
            else: branch=tree.fb #走左子樹
        else:
            if v==tree.value: branch=tree.tb
            else: branch=tree.fb
        return classify(observation,branch)
classify(['google','France','yes',23],tree)

疑問:爲什麼對於數字的屬性結點,是大於等於走右子樹,而不是小於等於?

使用基尼指數

在這裏插入圖片描述

def giniimpurity(rows):
    total = len(rows)
    counts = uniquecounts(rows)
    imp = 0
    for k1 in counts.keys():
        p1 = float(counts[k1])/total
        imp+= p1*(1-p1)
    return imp

剪枝處理

剪枝是決策樹學習算法對付”過擬合“的主要手段,基本策略有:預剪枝”和“後剪枝”,這裏介紹後剪枝。
找到劃分爲葉子節點的屬性結點,判斷劃分前後的驗證集的精度是否有提高,設定一個閾值,提高的精度大於這個閾值就支持剪枝,否則合併這些分支。
下面利用熵最爲評判依據

def prune(tree,mingain):
    # 如果分支不是葉節點,則對其進行剪枝操作
    if tree.tb.results==None:
        prune(tree.tb,mingain)
    if tree.fb.results==None:
        prune(tree.fb,mingain)

    # 如果兩個自分支都是葉節點,則判斷他們是否需要合併
    if tree.tb.results!=None and tree.fb.results!=None:
        # 構建合併後的數據集
        tb,fb=[],[]
        for v,c in tree.tb.results.items():
            tb+=[[v]]*c
        for v,c in tree.fb.results.items():
            fb+=[[v]]*c

        # 檢查熵的減少情況
        delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)

        if delta<mingain:
            # 合併分支
            tree.tb,tree.fb=None,None
            tree.results=uniquecounts(tb+fb)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章