【Machine Learning】- Decision Tree (Watermelon Dataset)

A code implementation of the decision tree chapter from 周志華 (Zhou Zhihua)'s "watermelon book" 《機器學習》.

# Build and test a decision tree (information-gain splitting) on the watermelon dataset
import numpy as np

# Compute the information entropy of a dataset (each row's last element is the class label)
def calcEntropy(dataSet):
    mD = len(dataSet)
    dataLabelList = [x[-1] for x in dataSet]
    dataLabelSet = set(dataLabelList)
    ent = 0
    for label in dataLabelSet:
        mDv = dataLabelList.count(label)
        prop = float(mDv) / mD
        ent = ent - prop * np.log2(prop)

    return ent
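
# Quick sanity check with hypothetical toy data (not part of the original post):
#   calcEntropy([['a', '是'], ['a', '是'], ['b', '否']])
#   = -(2/3)*log2(2/3) - (1/3)*log2(1/3) ≈ 0.918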

# Split the dataset on a feature value
# index   - column index of the feature to split on
# feature - the feature value to keep
# Returns the rows of dataSet whose value at `index` equals `feature`, with the `index` column removed
def splitDataSet(dataSet, index, feature):
    splitedDataSet = []
    for data in dataSet:
        if(data[index] == feature):
            sliceTmp = data[:index]
            sliceTmp.extend(data[index + 1:])
            splitedDataSet.append(sliceTmp)
    return splitedDataSet
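
# Example (hypothetical two-row dataset, not from the original post):
#   splitDataSet([['青綠', '是'], ['烏黑', '否']], 0, '青綠')  ->  [['是']]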

# Choose the best feature to split on, by information gain
# Returns the column index of the best feature
def chooseBestFeature(dataSet):
    entD = calcEntropy(dataSet)
    mD = len(dataSet)
    featureNumber = len(dataSet[0]) - 1
    maxGain = -100
    maxIndex = -1
    for i in range(featureNumber):
        entDCopy = entD
        featureI = [x[i] for x in dataSet]
        featureSet = set(featureI)
        for feature in featureSet:
            splitedDataSet = splitDataSet(dataSet, i, feature)  # subset where feature i == feature
            mDv = len(splitedDataSet)
            entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
        if(maxIndex == -1):
            maxGain = entDCopy
            maxIndex = i
        elif(maxGain < entDCopy):
            maxGain = entDCopy
            maxIndex = i

    return maxIndex
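
# The inner loop above computes the information gain from the book:
#   Gain(D, a) = Ent(D) - sum over values v of |D^v| / |D| * Ent(D^v)
# and chooseBestFeature returns the index of the feature with the largest gain.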

# Return the most frequent label (majority vote)
def mainLabel(labelList):
    labelRec = labelList[0]
    maxLabelCount = -1
    labelSet = set(labelList)
    for label in labelSet:
        if(labelList.count(label) > maxLabelCount):
            maxLabelCount = labelList.count(label)
            labelRec = label
    return labelRec
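
# Example: mainLabel(['是', '是', '否']) returns '是'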

# Recursively build the decision tree
def createDecisionTree(dataSet, featureNames):
    labelList = [x[-1] for x in dataSet]
    if(len(dataSet[0]) == 1): # no attributes left to split on
        return mainLabel(labelList)  # fall back to the majority label of this subset
    elif(labelList.count(labelList[0]) == len(labelList)): # all samples share the same label
        return labelList[0]

    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    featureList = [x[bestFeatureIndex] for x in dataSet]
    featureSet = set(featureList)
    for feature in featureSet:
        featureNamesNext = featureNames[:]
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createDecisionTree(splitedDataSet, featureNamesNext)
    return myTree

# Read watermelon dataset 2.0
def readWatermelonDataSet():
    ifile = open("周志華_西瓜數據集2.txt")
    featureName = ifile.readline()  # header line
    labels = (featureName.split(' ')[0]).split(',')
    lines = ifile.readlines()
    dataSet = []
    for line in lines:
        tmp = line.split('\n')[0]
        tmp = tmp.split(',')
        dataSet.append(tmp)

    return dataSet, labels

def main():
    # load the data
    dataSet, featureNames = readWatermelonDataSet()
    print(createDecisionTree(dataSet, featureNames))

if __name__ == "__main__":
    main()
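
The post does not include 周志華_西瓜數據集2.txt itself. Judging from how readWatermelonDataSet() parses it, the file is assumed to be a comma-separated text file with a header row of column names followed by one sample per line, roughly like this (watermelon dataset 2.0 from the book):

色澤,根蒂,敲聲,紋理,臍部,觸感,好瓜
青綠,蜷縮,濁響,清晰,凹陷,硬滑,是
烏黑,蜷縮,沉悶,清晰,凹陷,硬滑,是
...(17 samples in total)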

The decision tree that is finally output:
{'紋理': {'模糊': '否', '清晰': {'根蒂': {'稍蜷': {'色澤': {'烏黑': {'觸感': {'硬滑': '是', '軟粘': '否'}}, '青綠': '是'}}, '蜷縮': '是', '硬挺': '否'}}, '稍糊': {'觸感': {'硬滑': '否', '軟粘': '是'}}}}

Drawn out, it looks like this:
[figure: the decision tree above, drawn as a diagram]

This tree is not quite the same as the one in the book.
I later found a CSDN article which says the decision tree needs to be completed: every possible value of the chosen attribute should get its own branch, even when no sample in the current subset takes that value.
[figure: the completed decision tree from that article]
I then reread the pseudocode in the book.
[figure: the book's pseudocode, with the key line marked in red]
The question is how to read the line marked in red: does "each value" of the chosen attribute range over the values found in the original dataset, or only over the values present in the subset after splitting?
The code above does the latter; the book does the former.
For example, the subset reached via 紋理=清晰 → 根蒂=稍蜷 contains no sample with 色澤=淺白, so the tree above has no 淺白 branch there, while the book's version still adds that branch and labels it with the majority class of the parent subset.

Change the createDecisionTree() and readWatermelonDataSet() functions to the following:

# Build the complete decision tree
# featureNamesSet  - for each feature, the list of values it takes in the full dataset
# labelListParent  - the label list of the parent node's dataset
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
    labelList = [x[-1] for x in dataSet]
    if(len(dataSet) == 0): # empty branch: fall back to the parent's majority label
        return mainLabel(labelListParent)
    elif(len(dataSet[0]) == 1): # no attributes left to split on
        return mainLabel(labelList)  # fall back to the majority label of this subset
    elif(labelList.count(labelList[0]) == len(labelList)): # all samples share the same label
        return labelList[0]

    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    featureList = featureNamesSet.pop(bestFeatureIndex)  # all values this feature takes in the full dataset
    featureSet = set(featureList)
    for feature in featureSet:  # branch on every value, even if the subset for it is empty
        featureNamesNext = featureNames[:]
        featureNamesSetNext = featureNamesSet[:]  # a shallow copy is enough: the inner value lists are never modified
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
    return myTree


# Read watermelon dataset 2.0
def readWatermelonDataSet():
    ifile = open("周志華_西瓜數據集2.txt")
    featureName = ifile.readline()  # header line
    featureNames = (featureName.split(' ')[0]).split(',')
    lines = ifile.readlines()
    dataSet = []
    for line in lines:
        tmp = line.split('\n')[0]
        tmp = tmp.split(',')
        dataSet.append(tmp)

    # build featureNamesSet: for every feature, the list of values it takes in the full dataset
    featureNamesSet = []
    for i in range(len(dataSet[0]) - 1):
        col = [x[i] for x in dataSet]
        colSet = set(col)
        featureNamesSet.append(list(colSet))

    return dataSet, featureNames, featureNamesSet
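
The post does not show the updated main(). A minimal sketch, assuming the label list of the full dataset is passed in as the initial parent labels:

def main():
    dataSet, featureNames, featureNamesSet = readWatermelonDataSet()
    labelList = [x[-1] for x in dataSet]
    print(createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelList))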

Now the result matches the book:
[figure: the completed decision tree, identical to the one in the book]
