机器学习实战--决策树分类

关于决策树的讲解,在另一篇博客中我给过介绍,有兴趣的可以看下相关的内容,建议不了解原理的先了解决策树的原理,弄清算法的流程和几个基本概念。决策树分类算法
以下师python的决策树实现,采用的是信息增益来选取最好的属性,即 ID3算法:
参考机器学习实战,在实践中,给了一点自己的注释,希望能帮助大家理解。

# encoding:utf-8
from math import log
import operator


def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


# 计算熵
def calcShannoEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLable = featVec[-1]
        if currentLable not in labelCounts.keys():
            labelCounts[currentLable] = 0
        labelCounts[currentLable] += 1
    shannoEnt = 0.0
    for key in labelCounts:
        prof = float(labelCounts[key]) / numEntries  # 求得P(i)
        shannoEnt -= prof * log(prof, 2)  # 求-log2P(i)的期望值
    return shannoEnt


# 根据属性的下标和属性的值对数据集进行划分(这个方法和我们给出的数据是高度适配的,即每行的数据最后一个是分类标签,之前的每列代表一个属性,数据集不同的话,处理的过程也不同)
def splitDataSet(dataSet, axix, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axix] == value:
            # 按照指定的列划分数据后,划分后的数据需要去除该列的属性
            reducedFeatVec = featVec[:axix]  # 不包括axix列
            reducedFeatVec.extend(featVec[axix + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 属性的数量
    baseEntropy = calcShannoEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 第i个属性,即第i列的所有值(包含重复)
        uniqueVals = set(featList)  # 去重
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # 划分
            prob = len(subDataSet) / float(len(dataSet))  # 计算每个类别的概率
            newEntropy += prob * calcShannoEnt(subDataSet)  # 计算每个类别的熵
        infoGain = baseEntropy - newEntropy  # 计算根据该属性划分后的信息增益率
        if infoGain > bestInfoGain:  # 按照决策树划分的原则,会选择信息增益最大的属性进行划分
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """
    取占据大部分的类别
    :param classList: 类别列表
    :return:
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # 得到所有的类别
    if classList.count(classList[0]) == len(classList):  # 类别完全相同,则停止划分
        return classList[0]  # 返回类型信息
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 求得最好的划分属性下标
    bestFeatLabel = labels[bestFeat]  # 最好的划分属性名
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:    # 使用最好的属性划分后能够得到一些子数据集,对这些数据集继续进行划分
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


if __name__ == '__main__':
    myDat, labels = createDataSet()
    print createTree(myDat,labels)

运行的结果如下:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

那么训练好的决策树如何使用呢?
一般而言,训练好的决策树是一种知识,会作为知识存储起来,以便进行分类时直接使用。

下面进行对待分类的数据进行分类:

在上面的代码中加入下面的测试方法:

# 对测试数据进行分类
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]  # 当前划分的最好属性
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # 将标签字符串转换为索引
    # 层次遍历划分属性对应的划分值,判断测试数据是哪种
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
            break
    return classLabel

相应的main方法更改为如下内容:

if __name__ == '__main__':
    myDat, labels = createDataSet()
    templabels = labels[:]
    mytree = createTree(myDat,templabels)
    testVec = [1,0]
    print classify(mytree,labels,testVec)

说明:createTree会更改labels,因此,在使用createTree方法时,需要传入labels复制出的templabels。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章