《機器學習實戰》——在python中使用Matplotlib註解繪製樹形圖

# encoding=utf-8
#使用文本註解繪製樹形圖
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
#上面三行代碼定義文本框和箭頭格式
#定義決策樹決策結果的屬性,用字典來定義,也可寫作 decisionNode={boxstyle:'sawtooth',fc:'0.8'}
#其中boxstyle表示文本框類型,sawtooth是波浪型的,fc指的是註釋框顏色的深度
#arrowstyle表示箭頭的樣式

def plotNode(nodeTxt, centerPt, parentPt, nodeType):#該函數執行了實際的繪圖功能
#nodeTxt指要顯示的文本,centerPt指的是文本中心點,parentPt指向文本中心的點
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )


#獲取葉節點的數目
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]#字典的第一個鍵,也就是樹的第一個節點
    secondDict=myTree[firstStr]#這個鍵所對應的值,即該節點的所有子樹。
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#測試節點的數據類型是否爲字典
            numLeafs+=getNumLeafs(secondDict[key])#遞歸,如果是字典的話,繼續遍歷
        else:numLeafs+=1#如果不是字典型的話,說明是葉節點,則葉節點的數目加1
    return numLeafs
#獲取樹的層數
def getTreeDepth(myTree):#和上面的函數結果幾乎一致
    maxDepth=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth=1+getTreeDepth(secondDict[key])#遞歸
        else:thisDepth=1#一旦到達葉子節點將從遞歸調用中返回,並將計算深度加1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth
#構建兩棵樹,用來測試
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}},3: 'maybe'}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
#可視化
def plotMidText(cntrPt,parentPt,txtString):#計算父節點和子節點的中間位置,並在父子節點間填充文本信息
    # cntrPt文本中心點   parentPt 指向文本中心的點
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)#調用getNumLeafs()函數計算葉子節點數目(寬度)
    depth=getTreeDepth(myTree)#調用getTreeDepth(),計算樹的層數(深度)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#plotTree.totalW表示樹的深度
    plotMidText(cntrPt,parentPt,nodeTxt)#調用 plotMidText()函數,填充信息nodeTxt
    plotNode(firstStr,cntrPt,parentPt,decisionNode)#調用plotNode()函數,繪製帶箭頭的註解
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    #因從上往下畫,所以需要依次遞減y的座標值,plotTree.totalD表示存儲樹的深度
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))#遞歸
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalID#h繪製完所有子節點後,增加全局變量Y的偏移。

def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')#繪圖區域爲白色
    fig.clf()#清空繪圖區
    axprops = dict(xticks=[], yticks=[])#定義橫縱座標軸
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    #由全局變量createPlot.ax1定義一個繪圖區,111表示一行一列的第一個,frameon表示邊框,**axprops不顯示刻度
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

發佈了30 篇原創文章 · 獲贊 69 · 訪問量 15萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章