ML入門2.0 -- 手寫決策樹(Decision Tree)

決策樹簡介

Decision Tree 中文稱爲決策樹,是ML中第二種十分經典的算法,顧名思義其算法結構爲樹形結構 ,與上一篇博客中介紹的KNN 類似都可以用來解決分類問題的算法。 決策樹由下面三種元素構成:

  1. 根結點 :樣本數據的全集
  2. 內部節點 : 按不同特徵屬性劃分的集合
  3. 葉節點 : 決策的結果
    在這裏插入圖片描述
    決策樹是一種不完全的歸納法,通過層層推理來實現最終的分類,其思想可以類比於程序設計中If-else 分支判斷結構,但是if else 是程序員在已知分類結果的基礎上設計出的一種固定的判斷模式;而決策樹則是根據不同的數據集合訓練自動給出相應最優的判斷模式,實現最終的分類。決策樹是最簡單的機器學習算法,它易於實現,可解釋性強,完全符合人類的直觀思維,有着廣泛的應用。目前常用的決策樹算法有三種:ID3C4.5CART

決策樹原理

這裏我們主要講解ID3算法所用的原理,首先需要給出一個定義:信息熵 (entropy),這裏表示系統(某個事件)的混亂程度熵越大混亂程度越大熵的變化可以看做是信息增益,決策樹ID3算法的核心思想是以信息增益度量屬性(分類特徵)選擇,選擇分類後信息增益最大的**屬性(特徵)**進行分類。

首先給出信息熵的計算公式:
隨機變量Y的信息熵爲(Y爲決策變量)1

H(Y)=i=1np(yi)log2(1/p(yi))=i=1np(yi)log2p(yi)H(Y)=\sum_{i=1}^{n}p(y_i)\log_2(1/p(y_i))=-\sum_{i=1}^{n}p(y_i)\log_2p(y_i)

隨機變量Y關於X的條件信息熵爲(X爲條件變量)

H(YX)=i=1mp(xi)H(YX=xi)=i=1mp(xi)log2p(yixi)H(Y|X) = \sum_{i=1}^{m}p(x_i)H(Y|X=x_i)=-\sum_{i=1}^{m}p(x_i)\log_2p(y_i|x_i)

信息增益的公式:

H(Y)H(YX)H(Y)-H(Y|X)

這裏給出一點關於信息熵公式的推導:
首先我們需要知道信息得分測量方法,那麼我們不妨先考慮測量質量的方法:很久以前一些聰慧的原始人類,選定了一個物體A將其作爲參照物,把物體A的質量就稱爲1千克(KG) ,那麼測量其他物體(這裏假設測定物體B)的質量時就是觀察物體B相當於幾個 物體A,這裏的幾個就是KG前面的係數,那麼物體B的質量就等於 幾個KG ;那麼**幾個(n)**的計算公式很簡單:n = M(B)/ M(A)

那麼現在我們再來思考信息熵的測量:信息的不確定性就是信息熵的量度,先假定一個參照事件A的不確定性,類比質量測定,待測事件B的不確定性的計算就是看待測事件B的不確定性相當於多少個參照事件A的不確定性,這裏的多少個就是所謂的信息量。但是不同於質量的是,B的不確定性是又多個A的不確定性相乘得到,舉個例子:拋一枚硬幣有兩種不確定的情況:正面反面 ,但是拋三枚硬幣是的情況是 222 = 2^3 = 8種情況,所以信息的多少個(n)的計算公式就要用指數運算的反函數對數運算來得到,即 n= log2(B);以上闡述是針對待測事件B所有可能情況都是等概率時,那當B的可能情況之間的概率不相等的情況時又該如何測量信息熵? 也不難思考:就是分別測量每種可能情況的信息熵然後乘以該情況的概率 最後再將所有結果相加。所以重點就變成求不同概率情況的信息熵,而概率的倒數恰好就是該概率所對應情況所包含等概率事件的情況個數,例如某事件概率爲P=1/10;那麼1/P=10就是該事件所含等概率的額情況個數;所以這種情況的信息熵就爲n=Plog2(1/P),同理可以求得其他情況的信息熵,最後求和就是上文給出的總信息熵的計算公式 1。
文字解釋太枯燥,給出一個解釋信息熵的公式的視頻:

【學習觀11】爲什麼信息還有單位?如何計算信息量?

最後舉個計算信息熵的小例子,假定有四個射擊高手:A; B; C; D, 他們獲勝的概率分別爲P(A)=1/2; P(B)=1/4; P©=P(D)=1/8
設Y爲確定哪一位高手獲勝
H(X)=P(A)log2(1P(A))+P(B)log2(1P(B))+P(C)log2(1P(C))+P(D)log2(1P(D))=12log2(2)+14log2(4)+18log2(8)+18log2(8)=12+12+38+38=47bitsH(X)=P(A)\log_2(\frac{1}{P(A)})+P(B)\log_2(\frac{1}{P(B)})+P(C)\log_2(\frac{1}{P(C)})+P(D)\log_2(\frac{1}{P(D)})=\frac{1}{2}\log_2(2)+\frac{1}{4}\log_2(4)+\frac{1}{8}\log_2(8)+\frac{1}{8}\log_2(8)=\frac{1}{2}+\frac{1}{2}+\frac{3}{8}+\frac{3}{8}=\frac{4}{7}bits

決策樹舉例

這裏我們給出一個數據集weatherplay tennis 以此來構建決策樹
在這裏插入圖片描述

第一種辦法採用一個Machine Learning的軟件weka來構造

實驗截圖:

數據集導入

在這裏插入圖片描述

使用ID3分類算法

在這裏插入圖片描述

使用C4.5(這裏是J48)

在這裏插入圖片描述
PS:紅框位置爲生成的決策樹
weka使用教程

手寫ID3

簡單的決策樹主要分爲兩個步驟:
訓練: 建樹 (建立模型)
測試: 用樹 (使用模型)
演示使用的數據集依然爲上文中的weather數據集
Func1: Loaddata() 製作數據集,加載數據

#Step 1. Load the weather data
def Loaddata():
    '''
    make the weather dataset
    :return: weatherData(原始數據); featureName(天氣特徵); classValues(分類結果:是否打球)
    '''
    weatherData = [['Sunny','Hot','High','FALSE','N'],
        ['Sunny','Hot','High','TRUE','N'],
        ['Overcast','Hot','High','FALSE','P'],
        ['Rain','Mild','High','FALSE','P'],
        ['Rain','Cool','Normal','FALSE','P'],
        ['Rain','Cool','Normal','TRUE','N'],
        ['Overcast','Cool','Normal','TRUE','P'],
        ['Sunny','Mild','High','FALSE','N'],
        ['Sunny','Cool','Normal','FALSE','P'],
        ['Rain','Mild','Normal','FALSE','P'],
        ['Sunny','Mild','Normal','TRUE','P'],
        ['Overcast','Mild','High','TRUE','P'],
        ['Overcast','Hot','Normal','FALSE','P'],
        ['Rain','Mild','High','TRUE','N']]

    featureName = ['Outlook', 'Temperature', 'Humidity', 'Windy']
    classValues = ['P', 'N']

    return weatherData, featureName, classValues

Func2:calcShannonEnt(paraDataSet)計算香農熵

def calcShannonEnt(paraDataSet):
    '''
    計算給定數據集的香濃熵
    :param paraDataSet: 給定數據集
    :return: shannonEnt
    '''
    numInstances = len(paraDataSet)  # numInstances:當前給定數據集中數據的個數
    labelCounts = {}
    for featureVec in paraDataSet:  # featureVec:數據集中的單個數據
        tempLabel = featureVec[-1]
        if tempLabel not in labelCounts.keys():
            labelCounts[tempLabel] = 0
        labelCounts[tempLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts.keys():
        prob = float(labelCounts[key])/numInstances
        shannonEnt -= prob * math.log2(prob)

    return shannonEnt

Func3:splitDataSet(dataSet, axis, value) 劃分子數據集

def splitDataSet(dataSet, axis, value):
    '''
    劃分該出特徵下層的數據集
    :param dataSet: 數據集
    :param axis: 第幾個特徵
    :param value: 該特徵的值
    :return: resultDataSet:分類後的數據集
    '''
    resultDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
    # 因爲劃分當前數據的子集所以去掉當前數據集的分類標準(特徵)
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            resultDataSet.append(reducedFeatVec)
    return resultDataSet

Func4: chooseBestFeatureToSplit(dataSet) 選出最優的劃分特徵

def chooseBestFeatureToSplit(dataSet):
    '''
    選擇出最好的特徵進行劃分子數據集
    :param dataSet:數據集
    :return: bestFeature:決策出的劃分效果最好(信息增益最大的特徵)
    '''
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature  = -1
    for i in range(numFeatures):
#把第i個屬性的所有取值篩選出來組成一個list
        featList = [data[i] for data in dataSet]
#去除list中的重複值
        uniqueVals = set(featList)
        newEntropy = 0.0
#計算條件信息熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob*calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature

Func5: majorityCnt(classList) 衝突剪枝

def majorityCnt(classList):
    '''
    :param classList:
    :return:投票決定的類別
    '''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=itemgetter(1), reverse=True)
    return sortedClassCount

Func6:creatTree(dataSet, paraFeatureName)建立決策樹

def creatTree(dataSet, paraFeatureName):
    '''
    建樹
    :param dataSet:數據集
    :param paraFeatureName:數據集的屬性名稱
    :return: 遞歸創建完成的決策樹
    '''
    featureName = paraFeatureName.copy()  # 防止後面原本的數據修改導致的分類出錯
    classList = [example[-1] for example in dataSet]

#  如果當前的label只有一種類別則說明該子集已經完善
    if classList.count(classList[0]) == len(classList):
        return classList[0]

#  如果遇到屬性一致但是結果不同的衝突情形(無分類屬性可用),選擇佔比大的
    if len(dataSet) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatureName = featureName[bestFeat]
    myTree = {bestFeatureName:{}}
    del(featureName[bestFeat])
    featvalue = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featvalue)
    for value in uniqueVals:
        subfeatureName = featureName[:]
        myTree[bestFeatureName][value] = creatTree(splitDataSet(dataSet, bestFeat, value), subfeatureName)
    return myTree

Func7: id3Classify(paraTree, paraTestingSet, featureNames, classValues) ID3分類器

def id3Classify(paraTree, paraTestingSet, featureNames, classValues):
    '''
    ID3分類器
    :param paraTree: 已生成的決策樹
    :param paraTestingSet: 測試集
    :param featureNames: 特徵名稱
    :param classValues: 分類類型值
    :return: 正確率
    '''
    tempCorrect = 0.0
    tempTotal = len(paraTestingSet)
    tempPrediction = classValues[0]
    for featureVector in paraTestingSet:
        print("Instance: ", featureVector)
        tempTree = paraTree
        while True:
            for feature in featureNames:
                try:
                    tempTree[feature]
                    splitFeature = feature
                    break
                except:
                    i = 1
            attributeValue = featureVector[featureNames.index(splitFeature)]
            print(splitFeature, " = ", attributeValue)

            tempPrediction = tempTree[splitFeature][attributeValue]
            if tempPrediction in classValues:
                break
            else:
                tempTree = tempPrediction
        print("Prediction = ", tempPrediction)
        if featureVector[-1] == tempPrediction:
            tempCorrect += 1
    return tempCorrect/tempTotal

Func8:STID3Test() 測試程序

def STID3Test():
    weatherData, featureName, classValues = Loaddata()
    tempTree = creatTree(weatherData, featureName)
    print(tempTree)
    print("Before classification, feature names = ", featureName)
    tempAccuracy = id3Classify(tempTree, weatherData, featureName, classValues)
    print("The accuracy of ID3 classifier is {}%".format(tempAccuracy*100))

運行結果:

在這裏插入圖片描述

完整版程序見githhub

github地址

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章