機器學習實戰--決策樹分類

關於決策樹的講解,在另一篇博客中我給過介紹,有興趣的可以看下相關的內容,建議不瞭解原理的先了解決策樹的原理,弄清算法的流程和幾個基本概念。決策樹分類算法
以下師python的決策樹實現,採用的是信息增益來選取最好的屬性,即 ID3算法:
參考機器學習實戰,在實踐中,給了一點自己的註釋,希望能幫助大家理解。

# encoding:utf-8
from math import log
import operator


def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


# 計算熵
def calcShannoEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLable = featVec[-1]
        if currentLable not in labelCounts.keys():
            labelCounts[currentLable] = 0
        labelCounts[currentLable] += 1
    shannoEnt = 0.0
    for key in labelCounts:
        prof = float(labelCounts[key]) / numEntries  # 求得P(i)
        shannoEnt -= prof * log(prof, 2)  # 求-log2P(i)的期望值
    return shannoEnt


# 根據屬性的下標和屬性的值對數據集進行劃分(這個方法和我們給出的數據是高度適配的,即每行的數據最後一個是分類標籤,之前的每列代表一個屬性,數據集不同的話,處理的過程也不同)
def splitDataSet(dataSet, axix, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axix] == value:
            # 按照指定的列劃分數據後,劃分後的數據需要去除該列的屬性
            reducedFeatVec = featVec[:axix]  # 不包括axix列
            reducedFeatVec.extend(featVec[axix + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 屬性的數量
    baseEntropy = calcShannoEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 第i個屬性,即第i列的所有值(包含重複)
        uniqueVals = set(featList)  # 去重
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # 劃分
            prob = len(subDataSet) / float(len(dataSet))  # 計算每個類別的概率
            newEntropy += prob * calcShannoEnt(subDataSet)  # 計算每個類別的熵
        infoGain = baseEntropy - newEntropy  # 計算根據該屬性劃分後的信息增益率
        if infoGain > bestInfoGain:  # 按照決策樹劃分的原則,會選擇信息增益最大的屬性進行劃分
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """
    取佔據大部分的類別
    :param classList: 類別列表
    :return:
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # 得到所有的類別
    if classList.count(classList[0]) == len(classList):  # 類別完全相同,則停止劃分
        return classList[0]  # 返回類型信息
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 求得最好的劃分屬性下標
    bestFeatLabel = labels[bestFeat]  # 最好的劃分屬性名
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:    # 使用最好的屬性劃分後能夠得到一些子數據集,對這些數據集繼續進行劃分
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


if __name__ == '__main__':
    myDat, labels = createDataSet()
    print createTree(myDat,labels)

運行的結果如下:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

那麼訓練好的決策樹如何使用呢?
一般而言,訓練好的決策樹是一種知識,會作爲知識存儲起來,以便進行分類時直接使用。

下面進行對待分類的數據進行分類:

在上面的代碼中加入下面的測試方法:

# 對測試數據進行分類
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]  # 當前劃分的最好屬性
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # 將標籤字符串轉換爲索引
    # 層次遍歷劃分屬性對應的劃分值,判斷測試數據是哪種
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
            break
    return classLabel

相應的main方法更改爲如下內容:

if __name__ == '__main__':
    myDat, labels = createDataSet()
    templabels = labels[:]
    mytree = createTree(myDat,templabels)
    testVec = [1,0]
    print classify(mytree,labels,testVec)

說明:createTree會更改labels,因此,在使用createTree方法時,需要傳入labels複製出的templabels。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章