決策樹算法實現

decision-tree.py

本文爲 落魄陶陶 原創,轉載請註明出處
數據來源及源碼參見github

  • 學習並參考《機器學習實戰》第三章
  • 主要使用Pandas庫
  • decision-tree.py爲基本算法實現,基於數據fish.xlsx

理解核心:

  1. 數據的有序程度以熵來表示,信息增益越大,表明對數據的劃分越有效
  2. 遍歷每個字段嘗試對數據進行劃分後計算信息增益,每次取信息增益最大的劃分
  3. 如果劃分後某個分組中都屬於一類,停止劃分,否則遞歸調用步驟2進一步劃分

關鍵在於第2步,要明確,熵的衡量,每次都是以label列的有序程度來計算,
換句話說,根據Xi的不同取值對數據進行劃分,其對應的分類組Y更加有序,說明這是更好的劃分

import math

import pandas as pd


# 加載數據
def load_excel(file: str) -> pd.DataFrame:
    return pd.read_excel(file)


def load_csv(file: str, sep: str = ',') -> pd.DataFrame:
    return pd.read_csv(file, sep=sep, header=None)


# 計算熵 H=-∑p(xi)log(p(xi),2)
def calc_entropy(df: pd.DataFrame) -> float:
    total = df.shape[0]
    value_counts = df[df.columns[-1]].value_counts()
    entropy_items = value_counts. \
        apply(lambda x: x / total). \
        apply(lambda prob: prob * math.log2(prob))
    return -entropy_items.sum()


# 劃分子集
def split_data_frame(df, col_name, val):
    return df[df[col_name] == val].drop(col_name, axis=1)


# 選擇最好子集劃分
def choose_best_feature(df: pd.DataFrame) -> str:
    columns = df.columns[:-1]
    best_entropy = calc_entropy(df)
    best_info_gain = 0.
    best_column = None
    for col in columns:
        values = df[col].unique()
        new_entropy = 0.
        for val in values:
            subset = split_data_frame(df, col, val)
            new_entropy += calc_entropy(subset)
        info_gain = best_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_column = col
    return best_column


# 創建決策樹
def create_tree(df: pd.DataFrame):
    values = df[df.columns[-1]].unique()
    if values.size < 2:  # 所有label都相同,返回label
        return values[0]
    if df.shape[0] == 2:  # df中只有最後一列數據和label,不可進一步劃分,統計label中數量最多的爲最終label
        return df[df.columns[-1]].value_counts(ascending=False).values[0]
    best_column = choose_best_feature(df)
    tree = {best_column: {}}
    values_of_column = df[best_column].unique()
    for val in values_of_column:
        tree[best_column][val] = create_tree(split_data_frame(df, best_column, val))
    return tree


if __name__ == '__main__':
    # df = load_excel('fish.xlsx')
    df = load_csv('lenses.txt', '\t')
    # entropy = calc_entropy(df)
    # column = choose_best_feature(df)
    tree = create_tree(df)
    print(tree)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章