《統計學習方法》 決策樹 CART生成算法 分類樹 Python實現

先貼一個全代碼。後面有時間再來解釋。

# --*-- coding:utf-8 --*--
import numpy as np

class Node: #結點
    def __init__(self, data = None, lchild = None, rchild = None):
        self.data = data
        self.child = {} #需要用字典的key來做邊的值('是','否')

class DecisionTree4Cart:    #分類與迴歸樹
    def create(self, dataSet, labels):
        featureSet = self.createFeatureSet(dataSet)
        def createBranch(dataSet, featureSet):
            classLabel = [row[-1] for row in dataSet]    #按列讀取標籤
            node = Node()
            if (len(set(classLabel)) == 1):   #說明已經沒有需要劃分的了
                node.data = classLabel[0]
                node.child = None
                return node
            minGini = 1.1    #    #不會超過1
            minGiniIndex = -1
            minGiniFeature = None
            for key in featureSet:
                feature = featureSet[key]
                for x in feature:
                    gini = self.calcConditionalGini(dataSet, key, x) #計算基尼指數
                    print(gini)
                    if (minGini > gini):    #比較得出最小的基尼指數
                        minGini = gini
                        minGiniIndex = key
                        minGiniFeature = x
            node.data = labels[minGiniIndex]
            subFeatureSet = featureSet.copy()
            del subFeatureSet[minGiniIndex]#刪除特徵集(不再作爲劃分依據)
            subDataSet1 = [row for row in dataSet if row[minGiniIndex] == minGiniFeature]
            node.child[minGiniFeature] = createBranch(subDataSet1, subFeatureSet)
            subDataSet2 = [row for row in dataSet if row[minGiniIndex] != minGiniFeature]
            node.child["other"] = createBranch(subDataSet2, subFeatureSet)
            return node
        return createBranch(dataSet, featureSet)

    def calcConditionalGini(self, dataSet, featureIndex, value):   #計算基尼指數
        conditionalGini = 0
        """
        可以看出下面的代碼使按公式5.25來的嗎?
        """
        subDataSet1 = [row for row in dataSet if row[featureIndex] == value]    #按值劃分數據集,這是第一個數據集
        conditionalGini += len(subDataSet1) / float(len(dataSet)) * self.calcGini(subDataSet1)
        subDataSet2 = [row for row in dataSet if row[featureIndex] != value]    #第二個數據集
        conditionalGini += len(subDataSet2) / float(len(dataSet)) * self.calcGini(subDataSet2)
        return conditionalGini

    def calcGini(self, dataSet, featureKey = -1):   #計算基尼指數
        classLabel = [row[featureKey] for row in dataSet]
        labelSet = set(classLabel)  #類別的集合
        gini = 1
        for x in labelSet:  #此爲遍歷類標籤的類別,計算熵
            gini -= ((classLabel.count(x) / float(len(dataSet))) ** 2) #按公式5.24來
        return gini

    def preOrder(self, node, depth = 0):    #先序遍歷
        if (node != None):
            print(node.data, depth)
            if (node.child != None):
                for key in node.child:
                    print(key)
                    self.preOrder(node.child[key], depth + 1)

    def createFeatureSet(self, dataSet):    #創建特徵集
        featureSet = {}
        m, n = np.shape(dataSet)
        for i in range(n - 1):  #按列來遍歷,n-1代表不存入類別的特徵
            column = list(set([row[i] for row in dataSet]))    #按列提取數據
            featureSet[i] = column   #每一行就是每一維的特徵值
        return featureSet

    def classify(self, node, labels, testVec):  #類別判斷
        while node != None:
            if (node.data in labels):   #用來判斷是否內部結點,內部結點就繼續往下找
                index = labels.index(node.data) #非根結點意味着是根據某個特徵劃分的,找出該特徵的索引
                x = testVec[index]
                for key in node.child:  #遍歷結點孩子字典,用key來做權值來判斷該往左結點移動還是右節點
                    if x == key:
                        node = node.child[key]
                        break
                else:
                    node = node.child['other']
            else:
                break
        return node.data
if __name__ == '__main__':
    dataSet = [['青年', '否', '否', '一般', '否'],
           ['青年', '否', '否', '好', '否'],
           ['青年', '是', '否', '好', '是'],
           ['青年', '是', '是', '一般', '是'],
           ['青年', '否', '否', '一般', '否'],
           ['中年', '否', '否', '一般', '否'],
           ['中年', '否', '否', '好', '否'],
           ['中年', '是', '是', '好', '是'],
           ['中年', '否', '是', '非常好', '是'],
           ['中年', '否', '是', '非常好', '是'],
           ['老年', '否', '是', '非常好', '是'],
           ['老年', '否', '是', '好', '是'],
           ['老年', '是', '否', '好', '是'],
           ['老年', '是', '否', '非常好', '是'],
           ['老年', '否', '否', '一般', '否']]
    labels = ['年齡', '有工作', '有自己的房子', '信貸情況']
    tree = DecisionTree4Cart()
    node = tree.create(dataSet, labels)
    tree.preOrder(node)
    for dataLine in dataSet:
        print(tree.classify(node, labels, dataLine))

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章