"""decision-tree.py

本文爲 落魄陶陶 原創,轉載請註明出處
數據來源及源碼參見github
- 學習並參考《機器學習實戰》第三章
- 主要使用Pandas庫
- decision-tree.py爲基本算法實現,基於數據fish.xlsx

理解核心:
- 數據的有序程度以熵來表示,信息增益越大,表明對數據的劃分越有效
- 遍歷每個字段嘗試對數據進行劃分後計算信息增益,每次取信息增益最大的劃分
- 如果劃分後某個分組中都屬於一類,停止劃分,否則遞歸調用步驟2進一步劃分

關鍵在於第2步,要明確:熵的衡量,每次都是以label列的有序程度來計算。
換句話說,根據Xi的不同取值對數據進行劃分後,其對應的分類組Y更加有序,說明這是更好的劃分。
"""
import math
import pandas as pd
# 加載數據
def load_excel(file: str) -> pd.DataFrame:
    """Load an Excel workbook into a DataFrame via pandas."""
    frame = pd.read_excel(file)
    return frame
def load_csv(file: str, sep: str = ',') -> pd.DataFrame:
    """Load a headerless delimited text file into a DataFrame.

    Columns get default integer names because header=None.
    """
    frame = pd.read_csv(file, sep=sep, header=None)
    return frame
# 計算熵 H=-∑p(xi)log(p(xi),2)
def calc_entropy(df: pd.DataFrame) -> float:
    """Return the Shannon entropy H = -sum(p_i * log2(p_i)) of the last
    (label) column of df, where p_i is the relative frequency of label i."""
    n_rows = df.shape[0]
    probabilities = df.iloc[:, -1].value_counts() / n_rows
    return -sum(p * math.log2(p) for p in probabilities)
# 劃分子集
def split_data_frame(df, col_name, val):
    """Return the subset of df whose col_name equals val, with col_name dropped."""
    mask = df[col_name] == val
    return df.loc[mask].drop(columns=col_name)
# 選擇最好子集劃分
def choose_best_feature(df: pd.DataFrame) -> str:
    """Pick the feature column whose split gives the largest information gain.

    The last column of df is the class label; every other column is a
    candidate feature.  Returns the best column name, or None when no
    feature yields a positive information gain.
    """
    feature_columns = df.columns[:-1]
    total = df.shape[0]
    base_entropy = calc_entropy(df)
    best_info_gain = 0.
    best_column = None
    for col in feature_columns:
        # Conditional entropy H(Y|X=col): each subset's entropy must be
        # weighted by the fraction of rows it contains.  (The original
        # summed the subset entropies unweighted, which overstates the
        # entropy of features with many distinct values and can pick the
        # wrong split.)
        new_entropy = 0.
        for val in df[col].unique():
            subset = split_data_frame(df, col, val)
            prob = subset.shape[0] / total
            new_entropy += prob * calc_entropy(subset)
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_column = col
    return best_column
# 創建決策樹
def create_tree(df: pd.DataFrame):
    """Recursively build an ID3 decision tree from df as nested dicts.

    The last column of df is the class label.  A leaf is a label value;
    an internal node is {feature_name: {feature_value: subtree, ...}}.
    """
    labels = df[df.columns[-1]]
    unique_labels = labels.unique()
    if unique_labels.size < 2:
        # All remaining samples share one label: this is a leaf.
        return unique_labels[0]
    if df.shape[1] == 1:
        # No features left to split on (only the label column remains):
        # fall back to the majority label.  The original tested
        # df.shape[0] == 2 (a row count) and returned the count from
        # value_counts().values[0] instead of the label itself.
        return labels.value_counts().index[0]
    best_column = choose_best_feature(df)
    if best_column is None:
        # No feature offers positive information gain: majority label.
        return labels.value_counts().index[0]
    tree = {best_column: {}}
    for val in df[best_column].unique():
        tree[best_column][val] = create_tree(split_data_frame(df, best_column, val))
    return tree
if __name__ == '__main__':
    # Build and print a decision tree from the tab-separated lenses data.
    # (Swap in load_excel('fish.xlsx') to use the Excel dataset instead.)
    data = load_csv('lenses.txt', '\t')
    decision_tree = create_tree(data)
    print(decision_tree)