機器學習 | 決策樹

由於近期學業繁重QAQ,所以我就不說廢話了,直接上代碼~

運行效果

圖片描述


代碼

from math import log
import operator
import matplotlib.pyplot as plt

# Box and arrow styles handed to matplotlib's annotate() when drawing the tree.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}  # box for decision nodes
leafNode = {"boxstyle": "round4", "fc": "0.8"}  # box for leaf nodes
arrow_args = {"arrowstyle": "<-"}  # arrow linking a node to its parent

#畫樹

# Draw a tree node as a text annotation, with an arrow linking it to its parent.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the axes stored on ``createPlot.ax1`` with one boxed node.

    Args:
        nodeTxt: Text shown inside the node box.
        centerPt: (x, y) position of the node text, in axes-fraction coords.
        parentPt: (x, y) position of the parent, in axes-fraction coords.
        nodeType: Box style dict (e.g. ``decisionNode`` or ``leafNode``).
    """
    annotation_options = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    createPlot.ax1.annotate(nodeTxt, **annotation_options)
    
# Write a text label halfway between a child node and its parent.
def plotMidText(cntrPt, parentPt, txtString):
    """Place ``txtString`` at the midpoint of the child/parent segment.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: Label to draw on the axes stored on ``createPlot.ax1``.
    """
    midpoint = [
        (parentPt[axis] - cntrPt[axis]) / 2.0 + cntrPt[axis]
        for axis in (0, 1)
    ]
    createPlot.ax1.text(midpoint[0], midpoint[1], txtString)
    
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below the node located at ``parentPt``.

    Layout state is kept in function attributes that must be initialised
    before the first call: ``plotTree.totalW`` and ``plotTree.totalD`` scale
    the horizontal and vertical steps, while ``plotTree.xOff`` and
    ``plotTree.yOff`` hold the current drawing cursor.

    Args:
        myTree: Nested dict ``{feature_label: {feature_value: subtree_or_leaf}}``.
        parentPt: (x, y) position of the parent node, in axes-fraction coords.
        nodeTxt: Text written on the edge coming from the parent.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this decision node horizontally above the leaves it spans.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # A dict child is another decision node: recurse into it.
            plotTree(child, cntrPt, str(key))
        else:
            # A leaf: advance the horizontal cursor by one slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the vertical cursor so sibling subtrees start at the right level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    
    
def createPlot(inTree):
    """Render the decision tree ``inTree`` in a matplotlib window.

    Clears figure 1, creates a frameless, tick-less axes stored on
    ``createPlot.ax1``, initialises the layout state used by ``plotTree``
    and then blocks in ``plt.show()``.

    Args:
        inTree: Nested dict ``{feature_label: {feature_value: subtree_or_leaf}}``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Horizontal spacing is driven by the leaf count, vertical spacing by the
    # tree depth (the depth was previously taken from the leaf count too).
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a slot left of the axes so the first leaf lands on a slot
    # centre once the cursor is advanced.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

    
    


# Build the toy data set used by the demo.
def createDataSet():
    """Return the sample data set and its feature labels.

    Returns:
        A ``(dataSet, labels)`` pair. Each row of ``dataSet`` holds two
        feature values followed by the class label; ``labels`` names the
        two feature columns. Fresh lists are created on every call.
    """
    samples = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    featureNames = ('no surfacing', 'flippers')
    return [list(sample) for sample in samples], list(featureNames)

# Compute the Shannon entropy of a data set.
# The higher the entropy, the more mixed (less ordered) the class labels are.
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    Args:
        dataSet: Sequence of rows whose last element is the class label.

    Returns:
        The entropy as a float; ``0.0`` when ``dataSet`` is empty.
    """
    total = len(dataSet)
    # Tally how many rows carry each class label (the last column).
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        share = float(count) / total
        # Logarithm taken in base 2.
        entropy -= share * log(share, 2)
    return entropy
    
    
# Split the data set on one feature column.
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column ``axis`` equals ``value``, minus that column.

    Args:
        dataSet: List of rows (each row a list).
        axis: Index of the feature column to split on.
        value: Feature value a row must have in that column to be kept.

    Returns:
        A new list of new rows; the input rows are left untouched.
    """
    matches = []
    for row in dataSet:
        if row[axis] != value:
            continue
        # Keep what precedes the split column, then append what follows it.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        matches.append(trimmed)
    return matches
        

# Choose the best way to split the data set.
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Args:
        dataSet: List of rows; every column except the last is a feature and
            the last column is the class label.

    Returns:
        The column index of the best feature, or ``-1`` when no feature
        yields a strictly positive information gain.
    """
    featureCount = len(dataSet[0]) - 1
    parentEntropy = calcShannonEnt(dataSet)
    topGain, topIndex = 0.0, -1
    for column in range(featureCount):
        # Distinct values taken by this feature across the data set.
        distinct = set(row[column] for row in dataSet)
        weightedEntropy = 0.0
        for val in distinct:
            subset = splitDataSet(dataSet, column, val)
            weight = len(subset) / float(len(dataSet))
            # Entropy after the split, weighted by the subset's share.
            weightedEntropy += weight * calcShannonEnt(subset)
        gain = parentEntropy - weightedEntropy
        if gain > topGain:
            topGain, topIndex = gain, column
    return topIndex

# Takes a list of class names and returns the most frequent one.
def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    Ties are resolved in favour of the label that appears first.

    Args:
        classList: Non-empty list of class labels.

    Returns:
        The most frequent label.

    Raises:
        ValueError: If ``classList`` is empty.
    """
    # Map each distinct label to its number of occurrences.
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Pick the (label, count) pair with the highest count. The previous
    # ``sorted(classCount,iteritems(), ...)`` call raised NameError.
    return max(classCount.items(), key=operator.itemgetter(1))[0]

# Build the decision tree from a data set and its feature labels.
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Non-empty list of rows; the last column is the class label.
        labels: Feature names, one per feature column. Not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_label: {feature_value: subtree_or_leaf}}``.
    """
    # Class labels of every row, e.g. ['yes', 'yes', 'no', 'no', 'no'].
    classList = [example[-1] for example in dataSet]
    # First stop condition: every row carries the same class label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Second stop condition: all features are used up but the classes are
    # still mixed, so fall back to the most frequent class label.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    # Column index of the best feature to split on.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature gives a positive information gain: splitting further would
    # index the class-label column, so return the majority class instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    # Labels left after removing the chosen feature. Built as a new list so
    # the caller's ``labels`` is not mutated.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    # One branch per distinct value of the chosen feature.
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        # Recurse on the rows having this value, with the column removed.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
    
# Count the leaf nodes of a tree.
def getNumLeafs(myTree):
    """Return the number of leaves in ``myTree``.

    Args:
        myTree: Nested dict ``{feature_label: {feature_value: subtree_or_leaf}}``.
            Any child that is a dict (or dict subclass) is treated as a
            subtree; every other child counts as one leaf.

    Returns:
        The leaf count as an int.
    """
    firstStr = list(myTree)[0]
    children = myTree[firstStr].values()
    # A dict child is a decision node and is counted recursively.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in children
    )
    
# Compute the number of levels in a tree.
def getTreeDepth(myTree):
    """Return the depth of ``myTree`` counted in decision levels.

    Args:
        myTree: Nested dict ``{feature_label: {feature_value: subtree_or_leaf}}``.
            Any child that is a dict (or dict subclass) is treated as a
            subtree; every other child is a leaf.

    Returns:
        ``1`` for a single decision node with only leaves, one more per
        nested level, and ``0`` when the root has no children.
    """
    firstStr = list(myTree)[0]
    children = myTree[firstStr].values()
    # A leaf contributes one level; a subtree contributes one plus its depth.
    return max(
        (1 + getTreeDepth(child) if isinstance(child, dict) else 1
         for child in children),
        default=0,
    )

        
        
    
def main():
    """Build the decision tree for the sample data set, print it and plot it."""
    dataSet, labels = createDataSet()
    # Expected tree:
    # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    myTree = createTree(dataSet, labels)
    print('myTree:')
    print(myTree)
    createPlot(myTree)
    
    
    
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
    
    
    

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章