利用 Python 結合 Matplotlib 繪製樹圖形
參考:https://blog.csdn.net/maotianyi941005/article/details/82349032
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def getNumLeafs(myTree):
# 初始化結點數
numLeafs=0
firstSides = list(myTree.keys())
# 找到輸入的第一個元素,第一個關鍵詞爲劃分數據集類別的標籤
firstStr = firstSides[0]
secondDict = myTree[firstStr]
# 測試數據是否爲字典形式
for key in secondDict.keys():
# type判斷子結點是否爲字典類型
if type(secondDict[key]).__name__=='dict':
numLeafs+=getNumLeafs(secondDict[key])
#若子節點也爲字典,則也是判斷結點,需要遞歸獲取num
else: numLeafs+=1
# 返回整棵樹的結點數
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0]
# 獲取劃分類別的標籤
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
# 計算樹的寬度 totalW
numLeafs = getNumLeafs(myTree)
# 計算樹的高度 存儲在totalD
depth = getTreeDepth(myTree)
firstSides = list(myTree.keys())
# firstStr = myTree.keys()[0] 續作修改
# 找到輸入的第一個元素
firstStr = firstSides[0]
# 按照葉子結點個數劃分x軸
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# 標註子結點屬性值
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
# y方向上的擺放位置,自上而下繪製,遞減y值
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
# 判斷是否爲字典,不是則爲葉子結點
if type(secondDict[key]).__name__=='dict':
# 遞歸查找
plotTree(secondDict[key],cntrPt,str(key))
# 到達葉子結點
else:
# x方向計算結點座標
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
# 添加文本信息
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# 下次重新調用時恢復y
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# 主函數
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
# 在繪圖區上繪製兩個代表不同類型的樹節點
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers':
{0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
data = getDataSet('E:/bigdata/watermelon_3a.csv')
label = data.loc[:,'label']
dataSet = data.loc[ : , 'color':'sugar_ratio']
dataSet = dataSet.values.tolist()
dataSet
labels = label.values.tolist()
labels
mytree = dtm.createTree(dataSet,labels)
print(mytree)
createPlot(mytree)