Machine Learning in Action: Decision Trees (sklearn API)

Code

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn import tree
# sklearn.externals.six was removed in newer scikit-learn; the standard library StringIO works here
from io import StringIO

# pip install pydotplus
# pip install graphviz
import pydotplus

# Graphviz download: http://www.graphviz.org/download/
import os
# Add the Graphviz bin directory to PATH (Windows example path)
os.environ["PATH"] += os.pathsep + 'D:/program files (x86)/Graphviz2.38/bin'


def loadData():
    """
    加載文件,生成特徵集和目標值集
    :return:
    """
    # 加載文件
    with open('lenses.txt') as fr:
        # 處理文件,去掉每行兩頭的空白符,以\t分隔每個數據
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]

    # 提取每組數據的類別,保存在列表裏
    lenses_targt = []
    for each in lenses:
        # 存儲Label到lenses_targt中
        lenses_targt.append([each[-1]])

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

    # Dictionary holding the lenses data, used to build the pandas DataFrame
    lenses_dict = {}
    # Extract each feature column into the dictionary
    for each_label in lensesLabels:
        # Temporary list holding one feature column
        lenses_list = []
        for each in lenses:
            # index() returns the position of the first match of a value in a list
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
    # Build the pandas.DataFrame from the dictionary
    lenses_pd = pd.DataFrame(lenses_dict)
    print(lenses_targt)
    print(lenses_pd)

    return lenses_pd, lenses_targt


def dataEncoder(data_pd):
    le = LabelEncoder()
    # Encode each column as integers
    for col in data_pd.columns:
        # fit_transform() does two things: fit() learns the category-to-integer mapping, then transform() applies it
        # transform() only applies a mapping learned by a previous fit()
        # On new data, reuse the fitted encoder and call transform(); refitting would change the mapping
        # (see the LabelEncoder sketch after the script)
        data_pd[col] = le.fit_transform(data_pd[col])
    # Print the encoded result
    print(data_pd)


def createTree(data_pd, labels):
    # Create a DecisionTreeClassifier that splits on information gain (entropy), depth capped at 4
    clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
    # Build the decision tree from the data
    # fit(X, y): build a decision tree classifier from the training set (X, y)
    # Every sklearn estimator must be fitted before it can predict
    clf = clf.fit(data_pd.values.tolist(), labels)
    return clf


def exportTree(clf, feature_names):
    # Save the tree structure to a .dot file
    with open("lenses.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)

    # Render the tree to a PDF via Graphviz/pydotplus
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data,
                              feature_names=feature_names,
                              class_names=clf.classes_,
                              filled=True, rounded=True,
                              special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("tree.pdf")


def main():
    # Build the feature set and the target values
    data_pd, targts = loadData()
    # Encode the categorical features as integers
    dataEncoder(data_pd)
    # Train the tree (named clf so the sklearn tree module is not shadowed)
    clf = createTree(data_pd, targts)
    # Save the tree to .dot and render it to PDF
    exportTree(clf, data_pd.keys())

    # Predict for one encoded sample (values follow the column order of data_pd)
    print(clf.predict([[1, 1, 1, 0]]))


if __name__ == '__main__':
    main()
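
The comments in dataEncoder() gloss over the fit/transform split, so here is a minimal LabelEncoder sketch on its own (not part of the original script): fit() learns the category-to-integer mapping, transform() applies an already-learned mapping, fit_transform() does both in one call, and inverse_transform() maps the integers back.

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(['young', 'pre', 'presbyopic'])        # learn the mapping (classes are stored sorted)
print(le.classes_)                            # ['pre' 'presbyopic' 'young']
print(le.transform(['presbyopic', 'young']))  # apply the learned mapping -> [1 2]
print(le.fit_transform(['no', 'yes', 'no']))  # refit on new data in one step -> [0 1 0]
print(le.inverse_transform([1, 0]))           # map back to the labels -> ['yes' 'no']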

Output

[['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses']]
           age astigmatic prescript tearRate
0        young         no     myope  reduced
1        young         no     myope   normal
2        young        yes     myope  reduced
3        young        yes     myope   normal
4        young         no     hyper  reduced
5        young         no     hyper   normal
6        young        yes     hyper  reduced
7        young        yes     hyper   normal
8          pre         no     myope  reduced
9          pre         no     myope   normal
10         pre        yes     myope  reduced
11         pre        yes     myope   normal
12         pre         no     hyper  reduced
13         pre         no     hyper   normal
14         pre        yes     hyper  reduced
15         pre        yes     hyper   normal
16  presbyopic         no     myope  reduced
17  presbyopic         no     myope   normal
18  presbyopic        yes     myope  reduced
19  presbyopic        yes     myope   normal
20  presbyopic         no     hyper  reduced
21  presbyopic         no     hyper   normal
22  presbyopic        yes     hyper  reduced
23  presbyopic        yes     hyper   normal
    age  astigmatic  prescript  tearRate
0     2           0          1         1
1     2           0          1         0
2     2           1          1         1
3     2           1          1         0
4     2           0          0         1
5     2           0          0         0
6     2           1          0         1
7     2           1          0         0
8     0           0          1         1
9     0           0          1         0
10    0           1          1         1
11    0           1          1         0
12    0           0          0         1
13    0           0          0         0
14    0           1          0         1
15    0           1          0         0
16    1           0          1         1
17    1           0          1         0
18    1           1          1         1
19    1           1          1         0
20    1           0          0         1
21    1           0          0         0
22    1           1          0         1
23    1           1          0         0
['hard']

Process finished with exit code 0
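
The prediction at the end feeds an already-encoded row [1, 1, 1, 0] to the tree; reading it off the encoded table above (columns age, astigmatic, prescript, tearRate), that is a presbyopic, astigmatic, myopic patient with a normal tear rate, and the answer is 'hard'. To predict from raw category strings instead, one option is to keep the fitted encoder of each column; the helper below is an illustrative sketch, not part of the original script.

from sklearn.preprocessing import LabelEncoder

def encode_columns(data_pd):
    """Encode every column in place and return the fitted encoder for each column."""
    encoders = {}
    for col in data_pd.columns:
        le = LabelEncoder()
        data_pd[col] = le.fit_transform(data_pd[col])
        encoders[col] = le
    return encoders

# Usage, assuming data_pd and the trained clf from the script above:
# encoders = encode_columns(data_pd)                    # in place of dataEncoder(data_pd)
# sample = ['presbyopic', 'yes', 'myope', 'normal']     # raw values in data_pd.columns order
# encoded = [encoders[col].transform([value])[0]
#            for col, value in zip(data_pd.columns, sample)]
# print(clf.predict([encoded]))                         # expected to print ['hard'], as above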

lenses.dot

digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 0.5\nentropy = 1.326\nsamples = 24\nvalue = [4, 15, 5]"] ;
1 [label="X[1] <= 0.5\nentropy = 1.555\nsamples = 12\nvalue = [4, 3, 5]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[2] <= 0.5\nentropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]"] ;
1 -> 2 ;
3 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
2 -> 3 ;
4 [label="X[0] <= 0.5\nentropy = 0.918\nsamples = 3\nvalue = [0, 1, 2]"] ;
2 -> 4 ;
5 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
4 -> 5 ;
6 [label="entropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]"] ;
4 -> 6 ;
7 [label="X[2] <= 0.5\nentropy = 0.918\nsamples = 6\nvalue = [4, 2, 0]"] ;
1 -> 7 ;
8 [label="X[0] <= 1.5\nentropy = 0.918\nsamples = 3\nvalue = [1, 2, 0]"] ;
7 -> 8 ;
9 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0, 0]"] ;
8 -> 10 ;
11 [label="entropy = 0.0\nsamples = 3\nvalue = [3, 0, 0]"] ;
7 -> 11 ;
12 [label="entropy = 0.0\nsamples = 12\nvalue = [0, 12, 0]"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
}
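
The saved lenses.dot can also be rendered on its own, either with the Graphviz command line (dot -Tpdf lenses.dot -o lenses.pdf) or from Python via pydotplus; a minimal sketch:

import pydotplus

# Read the .dot file written by exportTree() and render it as an image
graph = pydotplus.graph_from_dot_file("lenses.dot")
graph.write_png("lenses.png")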

The tree diagram
[Figure: the decision tree rendered to tree.pdf by Graphviz]
