第三章 决策树 3.2 使用 Matplotlib 注解绘制树形图

上节学习了如何从数据集中创建树,但是字典的表示形式非常不易于理解,而且直接绘制图形也比较困难,这一节我们将使用 Matplotlib库 来创建树形图。


3.2.1 Matplotlib注解

Matplotlib提供了一个注解工具 annotations,可以在数据图形上添加文本注释。

创建一个文件,命名为 treePlotter.py ,然后输入:

# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt


# 定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8") # fc 应该是颜色深浅
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # centerPt 箭头指向座标, parentPt 箭头终点座标
    createPlot.ax1.annotate(nodeTxt, xy = parentPt,\
    xycoords = 'axes fraction',\
    xytext = centerPt, textcoords = 'axes fraction',\
    va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)


def createPlot():
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon = False)
    plotNode(U'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode) # U 这里指的是 utf 编码
    plotNode(U'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

建立一个运行文件 run_treePlotter.py ,输入:

# run_treePlotter.py
import treePlotter
print '>>> treePlotter.createPlot()'
treePlotter.createPlot()
结果如下:



看起来很不错的图片。这就是绘制树节点的方法。有个问题在于,字是乱码,不知道怎么解决。

所以决定改为英文。

    plotNode(U'decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode) # U 这里指的是 utf 编码
    plotNode(U'leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)


=========================================================================

3.2.2 构造注解树

我们虽然有 xy 座标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定 x 轴的长度;我们还需要知道树有多少层,以便可以正确确定 y 轴的高度。


这里我们定义两个新函数,来获取叶节点的数目和树的层数。

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0] # dict.keys() 返回字典的 keys
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 利用 type() 函数测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__ == 'dict': # 如果模块是被导入,__name__的值为模块名字
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs += 1
    return numLeafs


def getTreeDepth(myTree): # 计算遍历过程中遇到判断节点的个数。
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

然后,在 run_treePlotter.py 添加:

# run_treePlotter.py
import treePlotter
print '>>> treePlotter.createPlot()'
treePlotter.createPlot()
print '***************************************\n'
reload(treePlotter)
print '>>> treePlotter.retrieveTree(1)'
print treePlotter.retrieveTree(1)
print '>>> myTree = treePlotter.retrieveTree(0)'
print '>>> treePlotter.getNumLeafs(myTree)'
myTree = treePlotter.retrieveTree(0)
print treePlotter.getNumLeafs(myTree)
print '>>> treePlotter.getTreeDepth(myTree)'
print treePlotter.getTreeDepth(myTree)

结果是:





当然我们还是没完整把图画出来。

现在开始添加绘图代码:

def plotMidText(cntrPt, parentPt, txtString): # 在父子节点间填充文本信息 
    # 计算父节点和子节点的中间位置
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt): # 计算树的宽与高
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,\
                            plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt) # 标记子节点属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # 减少 y 偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),\
                cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

然后,改动之前的 createPlot(inTree) 函数为:

def createPlot(inTree):
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()
    axprops = dict(xticks = [], yticks = [])
    createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree)) # 储存树的宽度
    plotTree.totalD = float(getTreeDepth(inTree)) # 储存树的深度
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
准备运行代码, run_treePlotter.py 改为:

# run_treePlotter.py
import treePlotter
"""
print '>>> treePlotter.createPlot()'
treePlotter.createPlot()
print '***************************************\n'
reload(treePlotter)
print '>>> treePlotter.retrieveTree(1)'
print treePlotter.retrieveTree(1)
print '>>> myTree = treePlotter.retrieveTree(0)'
print '>>> treePlotter.getNumLeafs(myTree)'
myTree = treePlotter.retrieveTree(0)
print treePlotter.getNumLeafs(myTree)
print '>>> treePlotter.getTreeDepth(myTree)'
print treePlotter.getTreeDepth(myTree)
"""
print '***************************************\n'
reload(treePlotter)
print '>>> myTree = treePlotter.retrieveTree(0)'
myTree = treePlotter.retrieveTree(0)
print '>>> treePlotter.createPlot(myTree)'
print treePlotter.createPlot(myTree)
运行后图片为:



没有座标轴标签,我们要在运行文件里面添加一些命令,重新绘制树形图。

# run_treePlotter.py
import treePlotter

print '>>> myTree = treePlotter.retrieveTree(0)'
myTree = treePlotter.retrieveTree(0)
print ">>> myTree['no surfacing'][3] = maybe" # 增加分支
myTree['no surfacing'][3] = 'maybe'
print '>>> myTree'
print myTree
print '>>> treePlotter.createPlot(myTree)'
print treePlotter.createPlot(myTree)
运行结果:








發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章