決策樹算法

ID3決策樹算法類似算法流程圖。



決策樹算法

優點:計算複雜度不高,輸出結果易於理解,對中間值的缺失不敏感,可以處理不相關特徵數據。

缺點:可能會產生過度匹配問題。

適用數據類型:數值型和標稱型

基於Python的實現代碼:

1)準備子函數

  1. # -*- coding: cp936 -*-  
  2.   
  3. from math import log  
  4. import operator  
  5.   
  6. def createDataSet():#創建數據集  
  7.     dataSet = [[11'yes'],  
  8.                [11'yes'],  
  9.                [10'no'],  
  10.                [01'no'],  
  11.                [01'no']]  
  12.     labels = ['no surfacing','flippers']  
  13.     #change to discrete values  
  14.     return dataSet, labels  
  15.   
  16. def calcShannonEnt(dataSet):  
  17.     numEntries = len(dataSet)     #計算數據集的長度  
  18.     labelCounts = {}              #定義一個label字典,統計每個label出現的次數,鍵值爲label,值爲對應label出現的次數  
  19.     for featVec in dataSet:       #the the number of unique elements and their occurance  
  20.         currentLabel = featVec[-1]#數據集每個元素都是一個列表,每個元素列表的最後一列爲label  
  21.         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 #判斷當前label是否已經存在字典鍵值列表中,沒有存在的話,將當前label加入字典,並設置對應值爲0  
  22.         labelCounts[currentLabel] += 1                                           #否則,當前label出現次數累加  
  23.     shannonEnt = 0.0  
  24.     for key in labelCounts:  
  25.         prob = float(labelCounts[key])/numEntries #計算每個label出現的概率  
  26.         shannonEnt -= prob * log(prob,2)          #計算數據集的(香農信息熵)信息熵,其中log base 2  
  27.     return shannonEnt  
  28.       
  29. def splitDataSet(dataSet, axis, value):#按照給定特徵劃分數據集:待劃分數據集,劃分數據集的特徵,需要返回的特徵值  
  30.     retDataSet = []                    #定義一個空列表,即:子數據集  
  31.     for featVec in dataSet:  
  32.         if featVec[axis] == value:                  #判斷待劃分數據集中元素列表指定位置的特徵是否與需要返回的特徵值匹配  
  33.             reducedFeatVec = featVec[:axis]         #chop out axis used for splitting  
  34.             reducedFeatVec.extend(featVec[axis+1:]) #獲取待劃分數據集中元素列表的子元素列表(已經裁剪掉指定數據集的特徵)  
  35.             retDataSet.append(reducedFeatVec)       #添加獲取的子元素列表到子數據集中  
  36.     return retDataSet  
  37.       
  38. def chooseBestFeatureToSplit(dataSet):     #選擇最好的數據集劃分方式--以不同特徵劃分子數據集的信息熵增益(或者數據集信息熵減少)大小爲依據!  
  39.     numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels  
  40.     baseEntropy = calcShannonEnt(dataSet)  #計算整個數據集的信息熵  
  41.     bestInfoGain = 0.0; bestFeature = -1  
  42.     for i in range(numFeatures):                      #iterate over all the features  
  43.         featList = [example[i] for example in dataSet]#create a list of all the examples of this feature 運用到列表推導式  
  44.         uniqueVals = set(featList)                    #get a set of unique values  
  45.         newEntropy = 0.0  
  46.         for value in uniqueVals:  
  47.             subDataSet = splitDataSet(dataSet, i, value)  
  48.             prob = len(subDataSet)/float(len(dataSet))  
  49.             newEntropy += prob * calcShannonEnt(subDataSet)       
  50.         infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy  
  51.         if (infoGain > bestInfoGain):           #compare this to the best gain so far  
  52.             bestInfoGain = infoGain             #if better than current best, set to best  
  53.             bestFeature = i  
  54.     return bestFeature                          #returns an integer  
2) 構建決策樹
  1. def majorityCnt(classList):                     #運用多數表決方法判定label不唯一時,葉子節點的分類  
  2.     classCount={}  
  3.     for vote in classList:  
  4.         if vote not in classCount.keys(): classCount[vote] = 0  
  5.         classCount[vote] += 1  
  6.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
  7.     return sortedClassCount[0][0]               #返回label出現次數最多的所屬分類  
  8.   
  9. def createTree(dataSet,labels):                 #創建決策樹  
  10.     classList = [example[-1for example in dataSet]  
  11.     if classList.count(classList[0]) == len(classList): #list.count(list[0])返回指定位置0對應值list[0],出現的次數  
  12.         return classList[0]                     #stop splitting when all of the classes are equal  
  13.     if len(dataSet[0]) == 1:                    #stop splitting when there are no more features in dataSet  
  14.         return majorityCnt(classList)  
  15.     bestFeat = chooseBestFeatureToSplit(dataSet)  
  16.     bestFeatLabel = labels[bestFeat]  
  17.     myTree = {bestFeatLabel:{}}  
  18.     del(labels[bestFeat])                      #刪除已經使用的最佳劃分數據集特徵  
  19.     featValues = [example[bestFeat] for example in dataSet]  
  20.     uniqueVals = set(featValues)  
  21.     for value in uniqueVals:  
  22.         subLabels = labels[:]                  #copy all of labels, so trees don't mess up existing labels  
  23.         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)  
  24.     return myTree                              
  25.       
  26. def classify(inputTree,featLabels,testVec):    #使用決策樹分類函數進行分類  
  27.     firstStr = inputTree.keys()[0]  
  28.     secondDict = inputTree[firstStr]  
  29.     featIndex = featLabels.index(firstStr)  
  30.     key = testVec[featIndex]  
  31.     valueOfFeat = secondDict[key]  
  32.     if isinstance(valueOfFeat, dict):          #判斷是否爲字典類型的節點,如果是,則該節點爲判斷節點,否則,該節點爲葉子節點  
  33.         classLabel = classify(valueOfFeat, featLabels, testVec)  
  34.     else: classLabel = valueOfFeat  
  35.     return classLabel  
  36.   
  37. def storeTree(inputTree,filename): #利用pickle模塊存儲已經創建好的決策樹,以便後續使用中無需重新構建             
  38.     import pickle  
  39.     fw = open(filename,'w')  
  40.     pickle.dump(inputTree,fw)  
  41.     fw.close()  
  42.       
  43. def grabTree(filename):  
  44.     import pickle  
  45.     fr = open(filename)  
  46.     return pickle.load(fr)  

3) 程序運行截圖:(這裏用的pythonxy裏面的IPython(sh)交換環境)



實例測試

lenses.txt內容如下所示:

  1. young   myope   no  reduced no lenses  
  2. young   myope   no  normal  soft  
  3. young   myope   yes reduced no lenses  
  4. young   myope   yes normal  hard  
  5. young   hyper   no  reduced no lenses  
  6. young   hyper   no  normal  soft  
  7. young   hyper   yes reduced no lenses  
  8. young   hyper   yes normal  hard  
  9. pre myope   no  reduced no lenses  
  10. pre myope   no  normal  soft  
  11. pre myope   yes reduced no lenses  
  12. pre myope   yes normal  hard  
  13. pre hyper   no  reduced no lenses  
  14. pre hyper   no  normal  soft  
  15. pre hyper   yes reduced no lenses  
  16. pre hyper   yes normal  no lenses  
  17. presbyopic  myope   no  reduced no lenses  
  18. presbyopic  myope   no  normal  no lenses  
  19. presbyopic  myope   yes reduced no lenses  
  20. presbyopic  myope   yes normal  hard  
  21. presbyopic  hyper   no  reduced no lenses  
  22. presbyopic  hyper   no  normal  soft  
  23. presbyopic  hyper   yes reduced no lenses  
  24. presbyopic  hyper   yes normal  no lenses  


基於matplotlib模塊的python繪圖代碼如下所示:

  1. import matplotlib.pyplot as plt  
  2.   
  3. decisionNode = dict(boxstyle="sawtooth", fc="0.8")  
  4. leafNode = dict(boxstyle="round4", fc="0.8")  
  5. arrow_args = dict(arrowstyle="<-")  
  6.   
  7. def getNumLeafs(myTree):  
  8.     numLeafs = 0  
  9.     firstStr = myTree.keys()[0]  
  10.     secondDict = myTree[firstStr]  
  11.     for key in secondDict.keys():  
  12.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes  
  13.             numLeafs += getNumLeafs(secondDict[key])  
  14.         else:   numLeafs +=1  
  15.     return numLeafs  
  16.   
  17. def getTreeDepth(myTree):  
  18.     maxDepth = 0  
  19.     firstStr = myTree.keys()[0]  
  20.     secondDict = myTree[firstStr]  
  21.     for key in secondDict.keys():  
  22.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes  
  23.             thisDepth = 1 + getTreeDepth(secondDict[key])  
  24.         else:   thisDepth = 1  
  25.         if thisDepth > maxDepth: maxDepth = thisDepth  
  26.     return maxDepth  
  27.   
  28. def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
  29.     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',  
  30.              xytext=centerPt, textcoords='axes fraction',  
  31.              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )  
  32.       
  33. def plotMidText(cntrPt, parentPt, txtString):  
  34.     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  
  35.     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  
  36.     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)  
  37.   
  38. def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on  
  39.     numLeafs = getNumLeafs(myTree)      #this determines the x width of this tree  
  40.     depth = getTreeDepth(myTree)  
  41.     firstStr = myTree.keys()[0]         #the text label for this node should be this  
  42.     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  
  43.     plotMidText(cntrPt, parentPt, nodeTxt)  
  44.     plotNode(firstStr, cntrPt, parentPt, decisionNode)  
  45.     secondDict = myTree[firstStr]  
  46.     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
  47.     for key in secondDict.keys():  
  48.         if type(secondDict[key]).__name__=='dict':    #test to see if the nodes are dictonaires, if not they are leaf nodes     
  49.             plotTree(secondDict[key],cntrPt,str(key)) #recursion  
  50.         else:                                         #it's a leaf node print the leaf node  
  51.             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
  52.             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
  53.             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
  54.     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  
  55. #if you do get a dictonary you know it's a tree, and the first element will be another dict  
  56.   
  57. def createPlot(inTree):  
  58.     fig = plt.figure(1, facecolor='white')  
  59.     fig.clf()  
  60.     axprops = dict(xticks=[], yticks=[])  
  61.     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks  
  62.     #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
  63.     plotTree.totalW = float(getNumLeafs(inTree))  
  64.     plotTree.totalD = float(getTreeDepth(inTree))  
  65.     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  
  66.     plotTree(inTree, (0.5,1.0), '')  
  67.     plt.show()  
  68.   
  69. #def createPlot():  
  70. #    fig = plt.figure(1, facecolor='white')  
  71. #    fig.clf()  
  72. #    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
  73. #    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)  
  74. #    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)  
  75. #    plt.show()  
  76.   
  77. def retrieveTree(i):  
  78.     listOfTrees =[{'no surfacing': {0'no'1: {'flippers': {0'no'1'yes'}}}},  
  79.                   {'no surfacing': {0'no'1: {'flippers': {0: {'head': {0'no'1'yes'}}, 1'no'}}}}  
  80.                   ]  
  81.     return listOfTrees[i]  
  82.   
  83. #createPlot(thisTree)  



發佈了36 篇原創文章 · 獲贊 51 · 訪問量 19萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章