决策树算法实现

decision-tree.py

本文为 落魄陶陶 原创,转载请注明出处
数据来源及源码参见github

  • 学习并参考《机器学习实战》第三章
  • 主要使用Pandas库
  • decision-tree.py为基本算法实现,基于数据fish.xlsx

理解核心:

  1. 数据的有序程度以熵来表示,信息增益越大,表明对数据的划分越有效
  2. 遍历每个字段尝试对数据进行划分后计算信息增益,每次取信息增益最大的划分
  3. 如果划分后某个分组中都属于一类,停止划分,否则递归调用步骤2进一步划分

关键在于第2步,要明确,熵的衡量,每次都是以label列的有序程度来计算,
换句话说,根据Xi的不同取值对数据进行划分,其对应的分类组Y更加有序,说明这是更好的划分

import math

import pandas as pd


# 加载数据
def load_excel(file: str) -> pd.DataFrame:
    return pd.read_excel(file)


def load_csv(file: str, sep: str = ',') -> pd.DataFrame:
    return pd.read_csv(file, sep=sep, header=None)


# 计算熵 H=-∑p(xi)log(p(xi),2)
def calc_entropy(df: pd.DataFrame) -> float:
    total = df.shape[0]
    value_counts = df[df.columns[-1]].value_counts()
    entropy_items = value_counts. \
        apply(lambda x: x / total). \
        apply(lambda prob: prob * math.log2(prob))
    return -entropy_items.sum()


# 划分子集
def split_data_frame(df, col_name, val):
    return df[df[col_name] == val].drop(col_name, axis=1)


# 选择最好子集划分
def choose_best_feature(df: pd.DataFrame) -> str:
    columns = df.columns[:-1]
    best_entropy = calc_entropy(df)
    best_info_gain = 0.
    best_column = None
    for col in columns:
        values = df[col].unique()
        new_entropy = 0.
        for val in values:
            subset = split_data_frame(df, col, val)
            new_entropy += calc_entropy(subset)
        info_gain = best_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_column = col
    return best_column


# 创建决策树
def create_tree(df: pd.DataFrame):
    values = df[df.columns[-1]].unique()
    if values.size < 2:  # 所有label都相同,返回label
        return values[0]
    if df.shape[0] == 2:  # df中只有最后一列数据和label,不可进一步划分,统计label中数量最多的为最终label
        return df[df.columns[-1]].value_counts(ascending=False).values[0]
    best_column = choose_best_feature(df)
    tree = {best_column: {}}
    values_of_column = df[best_column].unique()
    for val in values_of_column:
        tree[best_column][val] = create_tree(split_data_frame(df, best_column, val))
    return tree


if __name__ == '__main__':
    # df = load_excel('fish.xlsx')
    df = load_csv('lenses.txt', '\t')
    # entropy = calc_entropy(df)
    # column = choose_best_feature(df)
    tree = create_tree(df)
    print(tree)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章