1. 決策樹原理篇
ID3算法:http://blog.csdn.net/zk_j1994/article/details/74066406
C4.5算法:http://blog.csdn.net/zk_j1994/article/details/74560278
CART算法:http://blog.csdn.net/zk_j1994/article/details/74606412
2. 決策樹實現
決策樹實現的算法步驟:
1)計算當前數據集的熵;
2)確定當前節點的劃分屬性,具體就是計算當前數據中每一個屬性的最大信息增益,選取信息增益最大的列作爲劃分屬性。
3)遞歸建樹,回到1),直到最大信息增益不再大於0(即無法再找到有效劃分)
# -*- coding: utf-8 -*-
# 基於CART的決策樹算法實現
from collections import Counter
from math import log2
class Node:
    """One node of a binary decision tree.

    Interior nodes carry the split criterion (``col``/``value``) and two
    child branches; leaf nodes carry a ``results`` dict of label counts
    and have ``results`` as the only non-default field.
    """

    def __init__(self, col=-1, value=None, results=None,
                 tb=None, fb=None):
        # Index of the column this node splits on (-1 for an unset node).
        self.col = col
        # The value of that column used as the split criterion.
        self.value = value
        # Label-count dict at a leaf; None for every interior node.
        self.results = results
        # Child reached when the split test passes.
        self.true_branch = tb
        # Child reached when the split test fails.
        self.false_branch = fb
def load_data():
    """Return the toy web-visitor dataset.

    Each row is [referrer, country, read_FAQ, pages_viewed, service],
    where the last column is the class label to predict.
    """
    rows = [
        ['slashdot', 'USA', 'yes', 18, 'None'],
        ['google', 'France', 'yes', 23, 'Premium'],
        ['digg', 'USA', 'yes', 24, 'Basic'],
        ['kiwitobes', 'France', 'yes', 23, 'Basic'],
        ['google', 'UK', 'no', 21, 'Premium'],
        ['(direct)', 'New Zealand', 'no', 12, 'None'],
        ['(direct)', 'UK', 'no', 21, 'Basic'],
        ['google', 'USA', 'no', 24, 'Premium'],
        ['slashdot', 'France', 'yes', 19, 'None'],
        ['digg', 'USA', 'no', 18, 'None'],
        ['google', 'UK', 'no', 18, 'None'],
        ['kiwitobes', 'UK', 'no', 19, 'None'],
        ['digg', 'New Zealand', 'yes', 12, 'Basic'],
        ['slashdot', 'UK', 'no', 21, 'None'],
        ['google', 'UK', 'yes', 18, 'Basic'],
        ['kiwitobes', 'France', 'yes', 19, 'Basic'],
    ]
    return rows
def _divide_set(data, column, value):
""" 對data在column上按照value拆分, 能夠處理數值型數據和標稱型數據 """
split_function = None
if isinstance(value, int) or isinstance(value, float):
split_function = lambda row: row[column] >= value
else:
split_function = lambda row: row[column] == value
# 將數據集拆分爲兩個集合, 並返回
set1 = [row for row in data if split_function(row)]
set2 = [row for row in data if not split_function(row)]
return (set1, set2)
def _label_counts(data):
""" 對data最後一列進行統計 """
result = {}
for row in data:
r = row[len(row) - 1]
result[r] = result.get(r, 0) + 1
return result
def _cal_entropy(data):
""" 計算當前數據集的熵 """
label_count = _label_counts(data)
counts = sum(label_count.values())
entropy = 0
for key in label_count.keys():
p = label_count[key] / counts
entropy += -(p * log2(p))
return entropy
def build_tree(data):
    """Recursively grow a decision tree over *data*.

    At each node, every distinct value of every feature column is tried
    as a binary split; the one with the largest information gain wins.
    Recursion stops when no split yields positive gain, producing a leaf
    whose ``results`` holds the label counts of the remaining rows.
    """
    if not data:
        return Node()

    base_entropy = _cal_entropy(data)

    # Track the best split found so far.
    best_gain = 0.0
    best_criteria = None
    best_sets = None

    n_features = len(data[0]) - 1  # last column is the label
    for col in range(n_features):
        # Try each distinct value of this column as a split point.
        for candidate in {row[col] for row in data}:
            true_set, false_set = _divide_set(data, col, candidate)
            # Information gain = parent entropy minus the size-weighted
            # entropy of the two children.
            weight = len(true_set) / len(data)
            gain = (base_entropy
                    - weight * _cal_entropy(true_set)
                    - (1 - weight) * _cal_entropy(false_set))
            if gain > best_gain and true_set and false_set:
                best_gain = gain
                best_criteria = (col, candidate)
                best_sets = (true_set, false_set)

    if best_gain > 0:
        # Recurse into both partitions of the winning split.
        return Node(col=best_criteria[0], value=best_criteria[1],
                    tb=build_tree(best_sets[0]),
                    fb=build_tree(best_sets[1]))
    # No useful split: emit a leaf with the label distribution.
    return Node(results=_label_counts(data))
def predict(new_sample, tree):
    """Classify *new_sample* by walking *tree* down to a leaf.

    Parameters
    ----------
    new_sample : feature row (without the label column).
    tree : root Node of a trained tree.

    Returns the leaf's ``results`` dict (label -> count).
    """
    if tree.results is not None:
        return tree.results
    value = new_sample[tree.col]
    if isinstance(value, (int, float)):
        # BUG FIX: training splits numerics with >= (see _divide_set),
        # but the original used a strict >, sending samples exactly equal
        # to the split value down the wrong branch. Mirror the training
        # rule here.
        branch = tree.true_branch if value >= tree.value else tree.false_branch
    else:
        branch = tree.true_branch if value == tree.value else tree.false_branch
    return predict(new_sample, branch)
def prune(tree, mingain):
    """Bottom-up post-pruning of a trained tree (modified in place).

    Parameters
    ----------
    tree : root Node of a trained tree; must be an interior node.
    mingain : threshold on the entropy increase from merging two sibling
        leaves. If merging them raises entropy by less than ``mingain``,
        the pair is collapsed into a single leaf.

    Returns the (possibly modified) *tree* for convenience.
    """
    # Recurse into non-leaf children first so merging happens bottom-up
    # and newly-created leaves can cascade upward.
    if tree.true_branch.results is None:
        prune(tree.true_branch, mingain)
    if tree.false_branch.results is None:
        prune(tree.false_branch, mingain)

    # If both children are now leaves, consider merging them.
    if tree.true_branch.results is not None and tree.false_branch.results is not None:
        # Rebuild label-only rows from each leaf's counts so the entropy
        # helpers can be reused on them.
        merged_true, merged_false = [], []
        for label, count in tree.true_branch.results.items():
            merged_true += [[label]] * count
        for label, count in tree.false_branch.results.items():
            merged_false += [[label]] * count

        # Entropy increase caused by undoing this split.
        delta = _cal_entropy(merged_true + merged_false) - \
                _cal_entropy(merged_true) - _cal_entropy(merged_false)
        if delta < mingain:
            # BUG FIX: the original assigned ``tree.tb, tree.fb`` — those
            # are the __init__ *parameter* names, not attributes, so the
            # stale branches stayed attached. Clear the real attributes.
            tree.true_branch, tree.false_branch = None, None
            tree.results = _label_counts(merged_true + merged_false)
    return tree
if __name__ == "__main__":
    data = load_data()
    tree = build_tree(data)
    # Classify an unseen visitor record: (referrer, country, read FAQ,
    # pages viewed). The original computed these values and silently
    # discarded them; print them so the demo actually shows its output.
    prediction = predict(["(direct)", "USA", "yes", 5], tree)
    print("prediction:", prediction)
    pruned_tree = prune(tree, 0.1)
    print("pruned root results:", pruned_tree.results)