決策樹(一):ID3算法

1.決策樹的基本原理與僞代碼

決策樹算法,是一種監督學習的分類算法,可細分爲ID3、C4.5、CART等三種算法,前兩種適用於標稱型數據,後一種適用於數值型數據。

1.1決策樹的基本原理:

所謂決策樹,即根據樣本數據集的不同特徵不斷對數據集進行劃分,劃分的最終結果構成一棵樹。

該算法的難點在於:在衆多特徵中,最先選擇哪一個特徵對數據集進行劃分?

ID3算法採用信息增益;C4.5算法採用信息增益率;CART算法採用基尼係數

本文主要介紹ID3算法,即以數據劃分前後的信息增益爲指標進行特徵選擇。

香農熵的計算公式:

信息增益=劃分前的香農熵—劃分後的條件熵

1.2決策樹算法的僞代碼:

createbranch():

if 所有樣本數據的標籤均一致

   返回該標籤

else 尋找劃分數據集的最好特徵

        劃分數據集

        創建分支節點

            for 每個劃分後的子集

                 調用函數createbranch並增加返回結果到分支節點中

return 分支節點

2.決策樹算法的優缺點

優點:形象直觀,易於理解,複雜度不高,可將分類器存儲在硬盤上,不用每次都重新學習

缺點:容易過擬合,需要剪枝(本文先不討論)

3.該算法的PYTHON語言實現

3.1決策樹分類算法的主體編程架構

一個完整的決策樹分類算法主要由以下幾塊構成:

  • 構建決策樹,重點在於選擇劃分數據集的最好特徵
  • 存儲決策樹並使用該決策樹對測試數據進行分類
  • 使用matplotlib繪製決策樹

3.2決策樹-ID3算法的PYTHON代碼

3.2.1構建決策樹

這一塊主要由以下幾個子函數構成:

  • 計算香農熵
  • 根據特徵值劃分數據集
  • 選擇劃分數據集的最好特徵
  • (主體函數)構建樹

首先,來計算香農熵:

#計算香農熵
def calcShannonEnt(dataSet):
    # 數據集的數據個數
    numEntries = len(dataSet)
    # 建立一個標籤字典,鍵值是每一行的標籤,對應值是該標籤出現的次數
    labelCounts = {}
    for featVec in dataSet:     # 遍歷數據集中的每一行
        currentLabel = featVec[-1]  # currentlable存放的是當前一行的標籤
        # 如果當前標籤不在之前的標籤字典裏,將標籤字典裏當前標籤對應的值賦0
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1  #統計當前標籤出現的次數 
        
    # 計算當前數據集的香農熵並返回
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries #prob爲當前標籤在所有標籤裏出現的概率
        shannonEnt -= prob * math.log(prob,2)  
    
return shannonEnt

其次,構建一個子函數,其功能是根據給定的特徵值來劃分數據集:

# ==============================================
# 輸入:
#        dataSet: 訓練集文件名(含路徑)
#        axis: 選定的特徵所在列數
#        value: 選定的特徵值
# 輸出:
#        retDataSet: 劃分後的子列表
# ==============================================
#函數功能:找出所有行中第axis個元素值爲value的行,去掉該元素,返回對應行矩陣
def splitDataSet(dataSet, axis, value):
    
    retDataSet = [] # 存放劃分後的子列表
    for featVec in dataSet:     # 逐行遍歷數據集
        if featVec[axis] == value:      # 如果當前行第axis列的特徵值等於value
            reducedFeatVec = featVec[:axis]  # 抽取掉數據集中的目標特徵值列
            reducedFeatVec.extend(featVec[axis+1:])
            # 將抽取後的數據加入到劃分結果列表中
            retDataSet.append(reducedFeatVec)         
    return retDataSet

注意:

  1. a=b[:axis],該語句提取出b列表的前0到(axis-1)列賦值給a,a.extend(b[axis+1:]),該語句在a列表後面加上b列表的(axis+1)到最後列。因此該兩句聯合起來就是刪除b列表中axis列的值。
  2. append與extend的區別:

        extend是將兩個列表相連,append是將新列表整體作爲一個對象一個元素添加到舊列表中

        list_extend = ['a', 'b', 'c'],list_extend.extend(['d', 'e', 'f']),print("list_extend:%s" %list_extend)
        # 輸出結果:list_extend:['a', 'b', 'c', 'd', 'e', 'f']
        list_append = ['a', 'b', 'c'],list_append.append(['d', 'e', 'f']),print("list_append:%s" %list_append)
        # 輸出結果:list_append:['a', 'b', 'c', ['d', 'e', 'f']]

然後,選擇最佳劃分特徵:

# ===============================================
# 輸入:
#        dataSet: 數據集
# 輸出:
#        bestFeature: 和原數據集熵差最大劃分對應的特徵的列號
# ===============================================
#選擇最佳的劃分特徵
def chooseBestFeatureToSplit(dataSet):
       
    numFeatures = len(dataSet[0]) - 1 #統計特徵總數
    baseEntropy = calcShannonEnt(dataSet) # 劃分前數據集的香農熵
    bestInfoGain = 0.0 # 暫存最大信息增益
    bestFeature = -1  # 暫存最大信息增益對應的特徵列號

    for i in range(numFeatures):    # 逐列遍歷數據集
        featList = [example[i] for example in dataSet]  # 獲取該列所有特徵值
        # 將特徵列featList的值去重並保存到集合uniqueVals
        uniqueVals = set(featList)         
        newEntropy = 0.0 #暫存劃分後數據集的香農熵
        # 計算該特徵劃分下所有劃分子集的香農熵,並求和。
        for value in uniqueVals:    # 遍歷該特徵列所有特徵值   
            subDataSet = splitDataSet(dataSet, i, value) #返回以第i個特徵的value值劃分後的子集
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        
        # 保存所有劃分法中,和原數據集熵差最大劃分對應的特徵的列號。
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
            print(bestFeature)
    return bestFeature

注意:

  1. 語句featList = [example[i] for example in dataSet]作用爲:
    將dataSet中的數據先按行依次放入example中,然後取得example中的example[i]元素,放入列表featList中
  2. set()函數的作用是,去掉列表中的重複元素。

最後,創建決策樹:

那麼如何用這些多層次的劃分子集搭建出一個樹結構呢?這部分更多涉及到編程技巧,某種程度上來說,就是用Python實現樹的問題。在Python中,可以用字典來具體實現樹:字典的鍵存放節點信息,值存放分支及子樹/葉子節點信息

# ===============================================
# 輸入:
#        dataSet: 數據集
#        labels: 標籤集
# 輸出:
#        myTree: 生成的決策樹
# ===============================================
#創建決策樹
def createTree(dataSet,labels):
    
    classList = [example[-1] for example in dataSet] # 獲得類標籤列表

    # 遞歸終止條件一:如果數據集內所有分類標籤均相同
    if classList.count(classList[0]) == len(classList): 
        return classList[0]
    
    # 遞歸終止條件二:如果所有特徵都劃分完畢
    if len(dataSet[0]) == 1:
        # 將它們都歸爲一類並返回
        return majorityCnt(classList)
    
    # 選擇最佳劃分特徵
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # 最佳劃分對應的劃分標籤。注意不是分類標籤
    bestFeatLabel = labels[bestFeat]
    # 構建字典空樹
    myTree = {bestFeatLabel:{}}
    # 從劃分標籤列表中刪掉劃分後的元素
    del(labels[bestFeat])
    # 獲取最佳劃分對應特徵的所有特徵值
    featValues = [example[bestFeat] for example in dataSet]
    # 對特徵值列表featValues唯一化,結果存於uniqueVals。
    uniqueVals = set(featValues)
    
    for value in uniqueVals:    # 逐行遍歷特徵值集合
        # 保存所有劃分標籤信息並將其夥同劃分後的數據集傳遞進下一次遞歸
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
        
    return myTree

3.2.2儲存決策樹並進行分類

在機器學習中,我們常常需要把訓練好的模型存儲起來,這樣在進行決策時直接將模型讀出,而不需要重新訓練模型,這樣就大大節約了時間。Python提供的pickle模塊就很好地解決了這個問題,它可以序列化對象並保存到磁盤中,並在需要的時候讀取出來,任何對象都可以執行序列化操作。

# ======================
# 輸入:
#        myTree:    決策樹
# 輸出:
#        決策樹文件
# ======================
def storeTree(inputTree,filename):   #'保存決策樹'
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()
    
# ========================
# 輸入:
#        filename:    決策樹文件名
# 輸出:
#        pickle.load(fr):    決策樹
# ========================    
def grabTree(filename):   #'打開決策樹'
    import pickle
    fr = open(filename)
    return pickle.load(fr)

 注意:

  • Pickle模塊中最常用的函數爲:

(1)pickle.dump(obj, file, [,protocol])

        函數的功能:將obj對象序列化存入已經打開的file中。

obj:想要序列化的obj對象;file:文件名稱;protocol:序列化使用的協議。如果該項省略,則默認爲0。如果爲負值或HIGHEST_PROTOCOL,則使用最高的協議版本。

(2)pickle.load(file)

        函數的功能:將file中的對象序列化讀出。file:文件名稱

 

將構造好的決策樹用於測試集數據的分類:

# ========================
# 輸入:
#        inputTree:    決策樹文件名
#        featLabels:    分類標籤集(特徵名稱集合)
#        testVec:        待分類的測試數據
# 輸出:
#        classLabel:    分類結果
# ======================== 
def classify(inputTree,featLabels,testVec):
    firstSides=List(inputTree.keys())
    firstStr=firstList[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)   # 找到當前分類標籤在分類標籤集中的下標
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDIct[key])._name_=='dict': #判斷是否到達葉子節點
                classLabel=classfy(secondDict[key],featLabels,testVec) #不是葉子節點,則繼續遍歷決策樹
            else: classLabel=secondDict[key] #若到達葉子節點,返回當前節點的分類標籤
    return classLabel

注意:

  • type()函數的作用爲返回該對象的數據類型:type(object) ;type(name, bases, dict)
  • #書上使用的是python2,代碼如下

    firstStr = myTree.keys()[0]

    #我使用的是python3,字典不能直接提取索引,就會報以上語句錯誤

    firstSides = list(myTree.keys())

    firstStr = firstSides[0]

    解決辦法就是先轉換成list,再把需要的索引提取出來。

3.2.3 使用matplotlib繪製決策樹

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# ===============================================
# 輸入:
#        myTree: 決策樹
# 輸出:
#        numLeafs: 決策樹的葉子數
# ===============================================
def getNumLeafs(myTree):
    '計算決策樹的葉子數'
    
    # 葉子數
    numLeafs = 0
    # 節點信息
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]
    # 分支信息
    secondDict = myTree[firstStr]
    
    for key in secondDict.keys():   # 遍歷所有分支
        # 子樹分支則遞歸計算
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        # 葉子分支則葉子數+1
        else:   numLeafs +=1
        
    return numLeafs

# ===============================================
# 輸入:
#        myTree: 決策樹
# 輸出:
#        maxDepth: 決策樹的深度
# ===============================================
def getTreeDepth(myTree):
    '計算決策樹的深度'
    
    # 最大深度
    maxDepth = 0
    # 節點信息
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]
    # 分支信息
    secondDict = myTree[firstStr]
    
    for key in secondDict.keys():   # 遍歷所有分支
        # 子樹分支則遞歸計算
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        # 葉子分支則葉子數+1
        else:   thisDepth = 1
        
        # 更新最大深度
        if thisDepth > maxDepth: maxDepth = thisDepth
        
    return maxDepth    
# ==================================================
# 輸入:
#        nodeTxt:     終端節點顯示內容
#        centerPt:    終端節點座標
#        parentPt:    起始節點座標
#        nodeType:    終端節點樣式
# 輸出:
#        在圖形界面中顯示輸入參數指定樣式的線段(終端帶節點)
# ==================================================
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '畫線(末端帶一個點)'
        
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

# =================================================================
# 輸入:
#        cntrPt:      終端節點座標
#        parentPt:    起始節點座標
#        txtString:   待顯示文本內容
# 輸出:
#        在圖形界面指定位置(cntrPt和parentPt中間)顯示文本內容(txtString)
# =================================================================
def plotMidText(cntrPt, parentPt, txtString):
    '在指定位置添加文本'
    
    # 中間位置座標
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

# ===================================
# 輸入:
#        myTree:    決策樹
#        parentPt:  根節點座標
#        nodeTxt:   根節點座標信息
# 輸出:
#        在圖形界面繪製決策樹
# ===================================
def plotTree(myTree, parentPt, nodeTxt):
    '繪製決策樹'
    
    # 當前樹的葉子數
    numLeafs = getNumLeafs(myTree)
    # 當前樹的節點信息
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]
    # 定位第一棵子樹的位置(這是蛋疼的一部分)
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    
    # 繪製當前節點到子樹節點(含子樹節點)的信息
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    
    # 獲取子樹信息
    secondDict = myTree[firstStr]
    # 開始繪製子樹,縱座標-1。        
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
      
    for key in secondDict.keys():   # 遍歷所有分支
        # 子樹分支則遞歸
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        # 葉子分支則直接繪製
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
     
    # 子樹繪製完畢,縱座標+1。
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

# ==============================
# 輸入:
#        myTree:    決策樹
# 輸出:
#        在圖形界面顯示決策樹
# ==============================
def createPlot(inTree):
    '顯示決策樹'
    
    # 創建新的圖像並清空 - 無橫縱座標
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    
    # 樹的總寬度 高度
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    
    # 當前繪製節點的座標
    plotTree.xOff = -0.5/plotTree.totalW; 
    plotTree.yOff = 1.0;
    
    # 繪製決策樹
    plotTree(inTree, (0.5,1.0), '')
    
    plt.show()

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章