'''
Pros: low computational complexity, output is easy to interpret, insensitive to missing intermediate values, can handle irrelevant features
Cons: prone to overfitting
Works with: numeric and nominal values
Information gain: ID3
Gain ratio: C4.5
Gini index: CART
'''
from math import log
import operator
# Compute the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)                        # total number of instances in the data set
    labelCounts = {}                                 # dictionary of class label counts
    for featVec in dataSet:
        currentLabel = featVec[-1]                   # the key is the value in the last column
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0            # if the key does not exist yet, add it to the dictionary
        labelCounts[currentLabel] += 1               # each key records how often that class occurs
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # probability of each class label
        shannonEnt -= prob * log(prob, 2)            # log base 2; accumulate the Shannon entropy
    return shannonEnt
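
# The header mentions the Gini index (CART) but the script only implements
# information gain (ID3). As a purely illustrative sketch, not part of the
# original code, calcGini below computes the Gini impurity the same way
# calcShannonEnt computes entropy, replacing -p*log2(p) with 1 - sum(p^2).
def calcGini(dataSet):
    # Gini(D) = 1 - sum_k p_k^2, where p_k is the fraction of class k in D
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        labelCounts[featVec[-1]] = labelCounts.get(featVec[-1], 0) + 1
    gini = 1.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        gini -= prob * prob
    return gini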
# Create a small sample data set and the corresponding feature labels
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
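
# Minimal usage sketch (assumed, not from the original script): build the toy
# data set and check its entropy; with 2 'yes' and 3 'no' labels the result is
# about 0.971.
myDat, labels = createDataSet()
print(calcShannonEnt(myDat))   # ~0.9709505944546686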
def splitDataSet(dataSet, axis, value):
    # Arguments: the data set to split, the index of the splitting feature, the feature value to keep
    retDataSet = []                                  # new list, so the original data set is not modified
    for featVec in dataSet:                          # iterate over every instance in the data set
        if featVec[axis] == value:                   # keep instances that match the requested feature value
            reducedFeatVec = featVec[:axis]          # cut out the splitting feature itself
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
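
# Usage sketch (assumed): splitting the toy data on feature 0 keeps the
# matching rows and drops the splitting column.
myDat, labels = createDataSet()
print(splitDataSet(myDat, 0, 1))   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))   # [[1, 'no'], [1, 'no']]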
def chooseBestFeatureToSplit(dataSet):
    # Choose the best feature to split the data set on
    '''
    Requirements on the input data:
    it must be a list of lists that all have the same length,
    and the last column (the last element of each instance) must be the class label
    '''
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)            # entropy of the whole data set
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):                     # iterate over all features of the data set
        featList = [example[i] for example in dataSet]   # all values of the i-th feature
        uniqueVals = set(featList)                   # convert to a set to get the unique values quickly
        newEntropy = 0.0
        for value in uniqueVals:                     # iterate over the unique values of the current feature
            subDataSet = splitDataSet(dataSet, i, value)  # split the data set once for each value
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted sum of the subset entropies
        infoGain = baseEntropy - newEntropy          # information gain = reduction in entropy
        if infoGain > bestInfoGain:                  # keep the feature with the largest information gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
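
# Usage sketch (assumed): on the toy data the best split is feature 0
# ('no surfacing'), with an information gain of about 0.42.
myDat, labels = createDataSet()
print(chooseBestFeatureToSplit(myDat))   # 0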
def majorityCnt(classList):
    # Majority vote: decide the class of a leaf node by the most frequent label
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
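
# Usage sketch (assumed): with a mixed label list the most frequent label wins.
print(majorityCnt(['yes', 'no', 'no', 'yes', 'no']))   # no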
def createTree(dataSet, labels):
    # Build the tree recursively; arguments are the data set and the list of feature labels
    classList = [example[-1] for example in dataSet]     # all class labels in the data set
    if classList.count(classList[0]) == len(classList):  # stop splitting when all classes are identical
        return classList[0]
    if len(dataSet[0]) == 1:                             # no features left: return the most frequent label
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)         # pick the best feature to split on
    bestFeatLabel = labels[bestFeat]                     # label of the best feature
    myTree = {bestFeatLabel: {}}                         # dictionary that stores the tree
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)                         # all values the best feature takes
    for value in uniqueVals:
        subLabels = labels[:]                            # copy the labels so the original list is not modified
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
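
# Usage sketch (assumed): build the tree on the toy data. createTree deletes
# entries from the labels list it is given, so a copy is passed here to keep
# labels intact for classify() below.
myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])
print(myTree)   # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}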
def classify(inputTree, featLabels, testVec):
    firstSides = list(inputTree.keys())
    firstStr = firstSides[0]                         # feature label used for the first split
    secondDict = inputTree[firstStr]                 # subtree attached to that feature
    featIndex = featLabels.index(firstStr)           # index of the first label matching firstStr
    classLabel = None                                # stays None if testVec matches no branch
    for key in secondDict.keys():                    # walk the tree, comparing testVec values with node values
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]         # reached a leaf: return its class label
    return classLabel
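
# Usage sketch (assumed): classify two test vectors with the tree built above.
myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])
print(classify(myTree, labels, [1, 0]))   # no
print(classify(myTree, labels, [1, 1]))   # yes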
def storeTree(inputTree, filename):
    # Serialize the tree with the pickle module so it can be saved to disk and loaded back later
    import pickle
    fw = open(filename, 'wb')                        # pickle needs binary mode in Python 3
    pickle.dump(inputTree, fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')                        # read back in binary mode
    return pickle.load(fr)
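
# Usage sketch (assumed; 'classifierStorage.txt' is just an example filename):
# store the tree on disk and read it back.
myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])
storeTree(myTree, 'classifierStorage.txt')
print(grabTree('classifierStorage.txt'))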
# Decision tree visualization
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle='sawtooth', fc='0.8')   # box and arrow styles for the annotations
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
# Draw a node with an arrow pointing to it from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    # Count the number of leaf nodes
    numLeafs = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]                         # feature label of the first split
    secondDict = myTree[firstStr]                    # subtree attached to that feature
    for key in secondDict.keys():                    # walk all child nodes of the tree
        if type(secondDict[key]).__name__ == 'dict': # a dict is a decision node, anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1                            # accumulate the number of leaves
    return numLeafs
def getTreeDepth(myTree):
    # Compute the number of levels in the tree
    maxDepth = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]                         # feature label of the first split
    secondDict = myTree[firstStr]                    # subtree attached to that feature
    for key in secondDict.keys():                    # count decision nodes along each branch
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    # Pre-stored trees for testing the plotting code
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]
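
# Usage sketch (assumed): the first stored tree has 3 leaves and 2 levels.
print(getNumLeafs(retrieveTree(0)))    # 3
print(getTreeDepth(retrieveTree(0)))   # 2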
def plotMidText(cntrPt, parentPt, txtString):
    # Put a small text label halfway between a parent node and its child
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)                   # width of the tree
    depth = getTreeDepth(myTree)                     # depth of the tree
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]                         # feature label of the first split
    # Track the positions already drawn and where to place the next node;
    # the width is used to center the decision node over all of its leaves
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    # drawing top-down, so decrease the y coordinate
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':         # decision node: recurse into plotTree()
            plotTree(secondDict[key], cntrPt, str(key))
        else:                                                # leaf node: draw it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD    # all children drawn: move the y offset back up
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()                                        # create a new figure and clear the drawing area
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Use the width and depth of the tree to place the nodes
    plotTree.totalW = float(getNumLeafs(inTree))     # width of the tree
    plotTree.totalD = float(getTreeDepth(inTree))    # depth of the tree
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
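
# Usage sketch (assumed): draw one of the pre-stored trees.
createPlot(retrieveTree(0))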