Machine Learning - Decision Tree Implementation

1. Decision Tree Theory

ID3 algorithm: http://blog.csdn.net/zk_j1994/article/details/74066406

C4.5 algorithm: http://blog.csdn.net/zk_j1994/article/details/74560278

CART algorithm: http://blog.csdn.net/zk_j1994/article/details/74606412


2. Decision Tree Implementation

Algorithm steps for building a decision tree:

1) Compute the entropy of the current data set (a small worked example follows this list);

2) Choose the split for the current node: for each attribute, compute the information gain of every candidate split on the current data, and select the attribute (column) whose best split yields the largest information gain;

3) Build the tree recursively: go back to 1) on each resulting subset, until the best information gain equals 0.
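
As a quick illustration of step 1), the snippet below computes the entropy of a hypothetical data set whose two classes occur 9 and 5 times; the counts are invented purely for illustration:

# Entropy of a hypothetical 9-vs-5 class distribution (illustrative only)
from math import log2

counts = [9, 5]
total = sum(counts)
entropy = -sum(c / total * log2(c / total) for c in counts)
print(round(entropy, 3))    # prints 0.94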

# -*- coding: utf-8 -*-
# A decision tree with CART-style binary splits; note that the impurity
# measure used below is entropy/information gain rather than the Gini index
from math import log2

class Node:
    def __init__(self, col=-1, value=None, results=None, 
                 tb=None, fb=None):
        self.col = col              # index of the attribute (column) used to split the data
        self.value = value          # the value of col that the split tests against
        self.results = results      # class counts for this branch; None for all non-leaf nodes
        self.true_branch = tb       # "true" branch (rows that satisfy the test)
        self.false_branch = fb      # "false" branch (rows that do not)

def load_data():
    """ Toy data set; the last column of each row is the class label. """
    data=[['slashdot','USA','yes',18,'None'],
          ['google','France','yes',23,'Premium'],
          ['digg','USA','yes',24,'Basic'],
          ['kiwitobes','France','yes',23,'Basic'],
          ['google','UK','no',21,'Premium'],
          ['(direct)','New Zealand','no',12,'None'],
          ['(direct)','UK','no',21,'Basic'],
          ['google','USA','no',24,'Premium'],
          ['slashdot','France','yes',19,'None'],
          ['digg','USA','no',18,'None'],
          ['google','UK','no',18,'None'],
          ['kiwitobes','UK','no',19,'None'],
          ['digg','New Zealand','yes',12,'Basic'],
          ['slashdot','UK','no',21,'None'],
          ['google','UK','yes',18,'Basic'],
          ['kiwitobes','France','yes',19,'Basic']]
    return data

def _divide_set(data, column, value):
    """ Split data on column by value; handles both numeric and nominal attributes. """
    if isinstance(value, (int, float)):
        split_function = lambda row: row[column] >= value
    else:
        split_function = lambda row: row[column] == value
    
    # Split the data set into two subsets and return them
    set1 = [row for row in data if split_function(row)]
    set2 = [row for row in data if not split_function(row)]
    return (set1, set2)

def _label_counts(data):
    """ Count the values in the last column (the class label) of data. """
    result = {}
    for row in data:
        r = row[-1]
        result[r] = result.get(r, 0) + 1
    return result

def _cal_entropy(data):
    """ Compute the entropy of the current data set. """
    label_count = _label_counts(data)
    counts = sum(label_count.values())
    
    entropy = 0.0
    for count in label_count.values():
        p = count / counts
        entropy -= p * log2(p)
    return entropy

def build_tree(data):
    """ Recursively build the tree. """
    if len(data) == 0:  return Node()
    
    # Compute the entropy of the current data
    current_entropy = _cal_entropy(data)
    
    # Variables tracking the best split found so far
    best_gain = 0.0
    best_criteria = None
    best_sets = None
    
    column_count = len(data[0]) - 1
    for col in range(column_count):
        # Collect all distinct values appearing in column col
        col_values = set()
        for row in data:
            col_values.add(row[col])
            
        # Try splitting the data on each value of column col
        for col_value in col_values:
            set1, set2 = _divide_set(data, col, col_value)
            
            # Compute the information gain of this split
            p = len(set1) / len(data)
            info_gain = current_entropy - p*_cal_entropy(set1) - (1-p)*_cal_entropy(set2)            
            if info_gain > best_gain and len(set1) > 0 and len(set2) > 0:
                best_gain = info_gain
                best_criteria = (col, col_value)
                best_sets = (set1, set2)
    
    # Create the sub-branches
    if best_gain > 0:
        true_branch = build_tree(best_sets[0])
        false_branch = build_tree(best_sets[1])
        return Node(col=best_criteria[0], value=best_criteria[1], 
                    tb=true_branch, fb=false_branch)
    else:
        return Node(results=_label_counts(data))
    
def predict(new_sample, tree):
    """ Classify a new sample by walking the tree. """
    if tree.results is not None:
        return tree.results
    else:
        value = new_sample[tree.col]
        branch = None
        if isinstance(value, (int, float)):
            # Use >= so prediction matches the split rule in _divide_set
            if value >= tree.value:  
                branch = tree.true_branch
            else:
                branch = tree.false_branch
        else:
            if value == tree.value:
                branch = tree.true_branch
            else:
                branch = tree.false_branch
    return predict(new_sample, branch)

def prune(tree, mingain):
    """ Prune the tree.
    tree:
        a trained tree
    mingain:
        pruning threshold: if the entropy reduction achieved by a split is
        less than mingain, its two leaves are merged back into one node
    """
    # If a branch is not a leaf, recursively prune it first
    if tree.true_branch.results is None:
        prune(tree.true_branch, mingain)
    if tree.false_branch.results is None:
        prune(tree.false_branch, mingain)
        
    # If both children are now leaves, consider merging them
    if tree.true_branch.results is not None and tree.false_branch.results is not None:
        # Reconstruct the (label-only) data sets held by the two leaves
        true_branch, false_branch = [], []
        for value, count in tree.true_branch.results.items():
            true_branch += [[value]] * count
        for value, count in tree.false_branch.results.items():
            false_branch += [[value]] * count
                            
        # Entropy reduction achieved by the split: entropy of the merged set
        # minus the mean entropy of the two leaves
        delta = _cal_entropy(true_branch + false_branch) -\
                (_cal_entropy(true_branch) + _cal_entropy(false_branch)) / 2
        if delta < mingain:
            # Merge: turn this node into a leaf holding the combined counts
            tree.true_branch, tree.false_branch = None, None
            tree.results = _label_counts(true_branch + false_branch)
    return tree
    
if __name__ == "__main__":
    data = load_data()
    tree = build_tree(data)
    prediction = predict(["(direct)", "USA", "yes", 5], tree)
    prune_tree = prune(tree, 0.1)
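
To inspect the learned (or pruned) tree, a small printing helper is convenient. The sketch below is an addition rather than part of the original post; the print_tree name and the output format are invented for illustration:

def print_tree(tree, indent=""):
    """ Recursively print the tree (illustrative helper, not from the original post). """
    if tree.results is not None:
        # Leaf node: print the class counts it holds
        print(indent + str(tree.results))
    else:
        # Internal node: print the split test, then both branches indented
        print(indent + "col %d : %s ?" % (tree.col, tree.value))
        print(indent + "T->")
        print_tree(tree.true_branch, indent + "    ")
        print(indent + "F->")
        print_tree(tree.false_branch, indent + "    ")

Calling print_tree(tree) after build_tree(data) prints one line per node, with each test's true and false branches indented beneath it.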

