1. Constructing the Decision Tree
Before constructing a decision tree, the first problem we need to solve is which feature of the current dataset plays the decisive role in splitting the data. To obtain the best split, we must evaluate every feature.
Data is commonly partitioned with binary splits; this article uses the ID3 algorithm to split the dataset.
The table below contains the data of 5 marine animals. Each animal has two features ('no surfacing' and 'flippers') and a label classifying it as fish or non-fish. We now need to decide whether to split the data on the first feature or the second.

    no surfacing   flippers   fish?
         1            1       yes
         1            1       yes
         1            0       no
         0            1       no
         0            1       no
2. Information Gain and Entropy
Before discussing how to split, let us first introduce information gain: the change in information before and after splitting a dataset is called the information gain. The measure of information in a set is called Shannon entropy, or simply entropy.
In 1948, Shannon introduced information entropy, defined in terms of the probabilities of discrete random events. The more ordered a system is, the lower its information entropy; conversely, the more disordered a system is, the higher its entropy. Information entropy can therefore be regarded as a measure of a system's degree of order.
Suppose a random variable $X$ can take the values $x_1, x_2, \ldots, x_n$, each with probability $p_1, p_2, \ldots, p_n$ respectively. Then the entropy of $X$ is defined as

$$H(X) = -\sum_{i=1}^{n} p_i \log_2 p_i$$

Intuitively, the more possible states a variable has, the more information it carries.
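As a quick numeric check of this definition, consider a label set with 2 positives and 3 negatives out of 5 samples (the same proportions as the fish dataset used later in this article):

```python
from math import log

# H = -sum(p_i * log2(p_i)) for the label distribution {2 yes, 3 no} out of 5.
probs = [2 / 5, 3 / 5]
entropy = -sum(p * log(p, 2) for p in probs)
print(round(entropy, 4))  # about 0.971
```

A uniform 50/50 split would give entropy 1.0, the maximum for two classes; the skew toward 'no' here lowers it slightly.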
For a classification system, the class label $C$ is the variable, taking values $C_1, C_2, \ldots, C_n$, where each class $C_i$ occurs with probability $P(C_i)$ and $n$ is the total number of classes. The entropy of the classification system can then be written as

$$H(C) = -\sum_{i=1}^{n} P(C_i) \log_2 P(C_i)$$
For more on entropy and information gain, see http://blog.csdn.net/acdreamers/article/details/44661149
3. Implementing the Decision Tree in Python
from math import log
import operator

def calShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels in dataSet."""
    numEntries = len(dataSet)
    labelCount = {}
    for featVec in dataSet:
        # The class label is the last column of each record.
        currentLabel = featVec[-1]
        if currentLabel not in labelCount:
            labelCount[currentLabel] = 0
        labelCount[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCount:
        prob = float(labelCount[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def createDataSet():
    # 5 marine animals: [no surfacing, flippers, fish label]
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
def splitDataSet(dataSet, axis, value):
    """Return the records whose feature `axis` equals `value`, with that feature removed."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
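To see what splitDataSet produces, the sketch below (a self-contained copy of the function above) splits the sample dataset on feature 0 with value 1, which keeps the three animals that can survive without surfacing and strips that column out:

```python
# Self-contained copy of splitDataSet, for a standalone demonstration.
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Keep the record, minus the feature we split on.
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(splitDataSet(dataSet, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
```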
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1        # the last column is the class label
    baseEntropy = calShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):             # iterate over all the features
        # collect this feature's value for every record, then deduplicate
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain = reduction in entropy
        if infoGain > bestInfoGain:          # keep the best gain seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
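Working the two candidate gains by hand for the sample dataset shows why 'no surfacing' (feature 0) wins. Splitting on feature 0 gives subsets {2 yes, 1 no} and {2 no}; splitting on feature 1 ('flippers') gives {2 yes, 2 no} and {1 no}:

```python
from math import log

def H(probs):
    # Shannon entropy of a probability list, skipping zero-probability terms.
    return -sum(p * log(p, 2) for p in probs if p > 0)

base = H([2 / 5, 3 / 5])  # entropy before any split, about 0.971
# Feature 0 ('no surfacing'): value 1 -> {2 yes, 1 no}, value 0 -> {2 no}
gain0 = base - (3 / 5 * H([2 / 3, 1 / 3]) + 2 / 5 * H([1.0]))
# Feature 1 ('flippers'): value 1 -> {2 yes, 2 no}, value 0 -> {1 no}
gain1 = base - (4 / 5 * H([1 / 2, 1 / 2]) + 1 / 5 * H([1.0]))
print(round(gain0, 3), round(gain1, 3))  # 0.42 0.171
```

Since feature 0 yields the larger gain (about 0.42 versus 0.17), chooseBestFeatureToSplit returns 0 for this dataset.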
def majorityCnt(classList):
    """Return the most common class label (used when features are exhausted)."""
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # items() replaces the Python 2-only iteritems()
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
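A quick check of the majority vote, using a self-contained copy of the function above on a made-up label list:

```python
import operator

# Self-contained copy of majorityCnt for a standalone demonstration.
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort labels by count, descending, and return the most frequent one.
    return sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)[0][0]

print(majorityCnt(['yes', 'no', 'yes', 'no', 'yes']))  # yes
```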
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                  # stop splitting when all classes are equal
    if len(dataSet[0]) == 1:                 # stop splitting when no features remain
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]                # copy labels so recursion doesn't mutate them
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
4. Results
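Running createTree on the sample dataset produces a nested dictionary representing the tree. The sketch below bundles self-contained copies of the functions above (written compactly, with the debug prints dropped) so the result can be reproduced standalone:

```python
from math import log
import operator

# Compact, self-contained copies of the functions defined above.
def calShannonEnt(dataSet):
    counts = {}
    for featVec in dataSet:
        counts[featVec[-1]] = counts.get(featVec[-1], 0) + 1
    n = len(dataSet)
    return -sum((c / n) * log(c / n, 2) for c in counts.values())

def splitDataSet(dataSet, axis, value):
    return [v[:axis] + v[axis + 1:] for v in dataSet if v[axis] == value]

def chooseBestFeatureToSplit(dataSet):
    baseEntropy = calShannonEnt(dataSet)
    bestInfoGain, bestFeature = 0.0, -1
    for i in range(len(dataSet[0]) - 1):
        newEntropy = 0.0
        for value in set(example[i] for example in dataSet):
            sub = splitDataSet(dataSet, i, value)
            newEntropy += len(sub) / len(dataSet) * calShannonEnt(sub)
        if baseEntropy - newEntropy > bestInfoGain:
            bestInfoGain, bestFeature = baseEntropy - newEntropy, i
    return bestFeature

def majorityCnt(classList):
    counts = {}
    for vote in classList:
        counts[vote] = counts.get(vote, 0) + 1
    return sorted(counts.items(), key=operator.itemgetter(1), reverse=True)[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                  # pure node: all labels identical
    if len(dataSet[0]) == 1:                 # no features left: majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        branch = splitDataSet(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(branch, subLabels[:])
    return myTree

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
tree = createTree(dataSet, labels)
print(tree)  # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```

The tree splits first on 'no surfacing' (the feature with the highest information gain), then on 'flippers' for the remaining mixed branch, classifying the two fish and three non-fish correctly.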