A Simple Python Implementation of Decision Tree Pruning

A decision tree is, as the name suggests, a tree built out of decisions. In machine learning, a decision tree is a predictive model that represents a mapping from object attributes to object values: each internal node corresponds to an attribute, each branch out of a node corresponds to a possible value of that attribute, and each leaf holds the value assigned to the objects described by the path from the root to that leaf. A decision tree has a single output; if multiple outputs are needed, a separate tree can be built for each output.
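Throughout the code below, a tree is stored as nested dictionaries of the form `{feature: {value: subtree_or_label}}`. A small hypothetical example (the feature names and class labels are made up) of the structure that `classify` and the plotting functions expect:

```python
# A hypothetical tree: internal nodes are dicts keyed by a feature label,
# branches are feature values, leaves are class labels.
sampleTree = {
    'texture': {
        'clear': 'good',                            # leaf
        'blurry': {
            'density<=0.38': {1: 'bad', 0: 'good'}  # binarized continuous split
        }
    }
}
```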

The ID3 algorithm: ID3 is a decision tree algorithm grounded in Occam's razor, i.e. do as much as possible with as little as possible. ID3 (Iterative Dichotomiser 3) was invented by Ross Quinlan. Following Occam's razor, smaller decision trees are preferred over larger ones; even so, the algorithm does not always produce the smallest tree. It is a heuristic. In information-theoretic terms, the smaller the expected information after a split, the larger the information gain and the higher the purity. The core idea of ID3 is to measure attribute choice by information gain and to split on the attribute that yields the largest information gain, searching the space of possible trees greedily from the top down.
Information entropy is defined in terms of the probabilities of the outcomes of a discrete random event: the more ordered a system is, the lower its entropy; the more disordered it is, the higher its entropy. Entropy can therefore be regarded as a measure of how ordered a system is.
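In symbols, using the standard ID3 definitions: with $p_i$ the proportion of class $i$ in data set $S$, the entropy and the information gain of attribute $A$ are

$$H(S) = -\sum_{i=1}^{c} p_i \log_2 p_i, \qquad Gain(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|}\, H(S_v)$$

where $S_v$ is the subset of $S$ on which $A$ takes value $v$; ID3 splits on the attribute with the largest $Gain(S, A)$.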

Gini index: CART chooses splits with the Gini index, defined as $gini(T) = 1 - \sum_{j=1}^{n} p_j^2$, where $p_j$ is the relative frequency of class $j$ in $T$; the more skewed the class distribution in $T$, the smaller $gini(T)$. After $T$ is partitioned into subsets $T_1$ (with $N_1$ instances) and $T_2$ (with $N_2$ instances), the Gini index of the split is defined as $gini_{split}(T) = \frac{N_1}{N} gini(T_1) + \frac{N_2}{N} gini(T_2)$, and the split with the smallest $gini_{split}(T)$ is chosen for the node.
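A quick numeric check of the skewness claim with two classes: at $p = (0.5, 0.5)$, $gini = 1 - (0.25 + 0.25) = 0.5$, the maximum for two classes, while at the skewed $p = (0.9, 0.1)$, $gini = 1 - (0.81 + 0.01) = 0.18$. Purer nodes score lower, which is why the smallest $gini_{split}$ is preferred.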
Implementation
First, `calcShannonEnt` computes the Shannon entropy of a data set, building a dictionary of counts for every possible class:

```python
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # build a dictionary counting every class label (the last column)
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # accumulate -p * log2(p) over all classes
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
```
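A quick sanity check on a made-up four-row data set (the last column is the class label):

```python
toySet = [[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
print(calcShannonEnt(toySet))  # 1.0: two equally likely classes
```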

```python
# Split the data set on a discrete feature: keep the samples whose feature
# at index `axis` equals `value`, dropping that feature column.
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
```
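Splitting the same toy set on feature 0 with value 1 keeps the two matching rows and drops the feature column:

```python
print(splitDataSet(toySet, 0, 1))  # [['yes'], ['yes']]
```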

For continuous features, `splitContinuousDataSet` splits the data set at a threshold `value`; the `direction` argument decides whether to return the samples less than or equal to `value` or those greater than it.
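The body of this function is missing from the original listing, so below is a minimal reconstruction, assuming the convention implied by its call sites in `chooseBestFeatureToSplit` and by the `<=` labels used when a continuous feature is binarized:

```python
# Reconstructed: split on a continuous feature at threshold `value`.
# direction == 0 keeps samples with featVec[axis] >  value;
# direction == 1 keeps samples with featVec[axis] <= value.
def splitContinuousDataSet(dataSet, axis, value, direction):
    retDataSet = []
    for featVec in dataSet:
        if direction == 0:
            if featVec[axis] > value:
                retDataSet.append(featVec[:axis] + featVec[axis + 1:])
        else:
            if featVec[axis] <= value:
                retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet
```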

`chooseBestFeatureToSplit` picks the split feature with the largest information gain, handling continuous and discrete features separately:

```python
def chooseBestFeatureToSplit(dataSet, labels):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitDict = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # continuous feature
        if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
            # generate the n-1 candidate split points (midpoints of sorted values)
            sortfeatList = sorted(featList)
            splitList = []
            for j in range(len(sortfeatList) - 1):
                splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

            bestSplitEntropy = float('inf')
            slen = len(splitList)
            # entropy obtained by splitting at the j-th candidate point;
            # remember the best split point
            for j in range(slen):
                value = splitList[j]
                newEntropy = 0.0
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                prob0 = len(subDataSet0) / float(len(dataSet))
                newEntropy += prob0 * calcShannonEnt(subDataSet0)
                prob1 = len(subDataSet1) / float(len(dataSet))
                newEntropy += prob1 * calcShannonEnt(subDataSet1)
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplit = j
            # record this feature's best split point in a dictionary
            bestSplitDict[labels[i]] = splitList[bestSplit]
            infoGain = baseEntropy - bestSplitEntropy
        # discrete feature
        else:
            uniqueVals = set(featList)
            newEntropy = 0.0
            # weighted entropy over every value of this feature
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # If the best feature is continuous, binarize it at the recorded split
    # point: 1 if the value is <= bestSplitValue, else 0.
    if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for i in range(len(dataSet)):
            if dataSet[i][bestFeature] <= bestSplitValue:
                dataSet[i][bestFeature] = 1
            else:
                dataSet[i][bestFeature] = 0
    return bestFeature
```
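A hypothetical run on four made-up rows (one continuous feature, one discrete, label last); note that `chooseBestFeatureToSplit` binarizes the winning continuous column in place and rewrites its label:

```python
data = [[0.2, 'a', 'no'], [0.4, 'b', 'no'], [0.6, 'a', 'yes'], [0.8, 'b', 'yes']]
labels = ['density', 'shape']
print(chooseBestFeatureToSplit(data, labels))  # 0: splitting density at 0.5 is pure
print(labels)  # ['density<=0.5', 'shape']
print(data)    # the density column is now 1 (<= 0.5) or 0 (> 0.5)
```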
`classify` walks the tree recursively. If a node label contains `<=`, it is a binarized continuous split, so the test value is compared against the threshold; otherwise the branch matching the discrete value is followed:

```python
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    if '<=' in firstStr:
        # continuous split: parse 'feature<=threshold' back apart
        featvalue = float(firstStr.split('<=')[1])
        featkey = firstStr.split('<=')[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(featkey)
        # branch 1 means 'value <= threshold', branch 0 means 'value > threshold'
        if testVec[featIndex] <= featvalue:
            judge = 1
        else:
            judge = 0
        for key in secondDict.keys():
            if judge == int(key):
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    else:
        # discrete split: follow the branch whose key equals the test value
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    return classLabel
```
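Tying back to the nested-dict format, a hypothetical call on a one-node tree:

```python
tree = {'density<=0.5': {1: 'no', 0: 'yes'}}
print(classify(tree, ['density', 'shape'], [0.7, 'a']))  # 'yes', since 0.7 > 0.5
```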
`majorityCnt` returns the most common class label. Note that the count must be supplied as the key function, because `max(classCount)` alone would compare the labels themselves:

```python
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # return the label with the highest count, not the "largest" label
    return max(classCount, key=classCount.get)
```
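A quick check: `majorityCnt(['yes', 'no', 'no'])` returns `'no'`; the original `return max(classCount)` would have returned `'yes'` here, because iterating a dict yields its keys, which `max` then compares as strings.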
`testing_feat` estimates the error on `test_data` if we split on feature `feat` and predict, within each branch, the majority class of the training samples; pre-pruning uses it to decide whether a split is worthwhile:

```python
def testing_feat(feat, train_data, test_data, labels):
    class_list = [example[-1] for example in train_data]
    bestFeatIndex = labels.index(feat)
    train_data = [example[bestFeatIndex] for example in train_data]
    test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
    all_feat = set(train_data)
    error = 0.0
    # for each feature value, predict the training majority class and
    # count the test samples that disagree
    for value in all_feat:
        class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
        major = majorityCnt(class_feat)
        for data in test_data:
            if data[0] == value and data[1] != major:
                error += 1.0
    return error
```

Testing

`testing` counts how many test samples a finished subtree misclassifies; its `def` line is reconstructed here from the call `testing(myTree, test_data, labels_copy)` in `createTree` below:

```python
def testing(myTree, data_test, labels):
    error = 0.0
    for i in range(len(data_test)):
        if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
            error += 1
    return float(error)
```
`testingMajor` is the baseline: the error when every test sample is predicted as the majority class `major`:

```python
def testingMajor(major, data_test):
    error = 0.0
    for i in range(len(data_test)):
        if major != data_test[i][-1]:
            error += 1
    return float(error)
```

**Recursively building the decision tree**

`createTree` builds the tree recursively; `mode` selects no pruning ("unpro"), pre-pruning ("prev"), or post-pruning ("post"):

```python
import copy

def createTree(dataSet, labels, data_full, labels_full, test_data, mode):
    classList = [example[-1] for example in dataSet]
    # all samples share one class: return that class as a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # no features left to split on: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    labels_copy = copy.deepcopy(labels)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    bestFeatLabel = labels[bestFeat]

    if mode == "unpro" or mode == "post":
        myTree = {bestFeatLabel: {}}
    elif mode == "prev":
        # pre-pruning: split only if it beats predicting the majority class
        if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < \
                testingMajor(majorityCnt(classList), test_data):
            myTree = {bestFeatLabel: {}}
        else:
            return majorityCnt(classList)
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    # discrete features are plain strings (Python 3); remember every value
    # seen in the full data set so branches absent from this subset
    # still get a leaf
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)

    del labels[bestFeat]

    for value in uniqueVals:
        subLabels = labels[:]
        if type(dataSet[0][bestFeat]).__name__ == 'str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels,
            data_full, labels_full,
            splitDataSet(test_data, bestFeat, value), mode=mode)
    # values never seen in this subset become majority-class leaves
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)

    # post-pruning: collapse the subtree if a majority-class leaf does better
    if mode == "post":
        if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
            return majorityCnt(classList)
    return myTree
```

**Loading the data**

`load_data` reads the CSV with pandas, drops the first (index) column, and uses the first 11 rows as the training set and the rest as the test set (the original hard-coded `"dd.csv"` and opened the file in binary mode; here the `file_name` argument is used directly):

```python
import pandas as pd

def load_data(file_name):
    df = pd.read_csv(file_name, sep=",")
    # column 0 is a row index; the last column is the class label
    train_data = df.values[:11, 1:].tolist()
    test_data = df.values[11:, 1:].tolist()
    labels = df.columns.values[1:-1].tolist()
    return train_data, test_data, labels
```
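The `dd.csv` file itself is not included in the post; from the slicing above it must have an index column first, the class label last, and the feature columns in between, along these entirely hypothetical lines:

```
id,color,density,label
0,green,0.46,yes
1,black,0.77,yes
2,white,0.24,no
```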





**Testing and plotting the tree**

```python
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round4", color='red')  # style for internal (decision) nodes
leafNode = dict(boxstyle="circle", color='grey')  # style for leaf nodes
arrow_args = dict(arrowstyle="<-", color='blue')  # arrow style


# count the leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# compute the maximum depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# draw a node, with an arrow from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
                            xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \
                            bbox=nodeType, arrowprops=arrow_args)


# draw the text on the arrow between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    lens = len(txtString)
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

if __name__ == "__main__":
    import json

    train_data, test_data, labels = load_data("dd.csv")
    data_full = train_data[:]
    labels_full = labels[:]

    # choose "unpro" (no pruning), "prev" (pre-pruning) or "post" (post-pruning)
    mode = "post"
    myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
    createPlot(myTree)
    print(json.dumps(myTree, ensure_ascii=False, indent=4))
```

Switching `mode` between the three values produces the three corresponding tree plots.
