《統計學習方法》 第五章 決策樹及sklearn包中決策樹算法的使用

本文內容參考李航老師的《統計學習方法》及其配套課件

python實現原文代碼作者:https://github.com/wzyonggege/statistical-learning-method

sklearn包中決策樹算法的使用資料鏈接:https://scikit-learn.org/dev/modules/tree.html

 

決策樹的簡介

        決策樹Decision Trees是一種用於分類迴歸( classification and regression)的無監督學習方法。目標是創建一個模型,從數據特徵中學習簡單的決策規則來預測目標變量的值。

 

例如,在下面的示例中,決策樹通過if-then-else的決策規則來學習數據從而估測到一個正弦曲線。樹越深,決策規則越複雜,模型對數據的擬合效果就越好。

關於決策樹的一些優點:

  • 便於理解和解釋。樹可以可視化。
  • 訓練需要的數據少,其他的模型方法通常需要數據規範化 ,比如創建虛擬變量並且刪除缺失值(注意,這個模型不支持缺失值)。
  • 訓練模型的時間複雜度是參與訓練數據點的對數值(訓練決策樹的數據點的數據量導致了)。
  • 使用白盒模型。如果某種給定的情況在模型中是可以觀測的,那麼就可以通過布爾邏輯解釋這種情況。相比之下,在黑盒模型中(比如,人的神經網絡)的結果就難以解釋。
  • 可以用數值統計測試來驗證模型。使解釋驗證該模型的可靠性成爲可能。
  • 即使該模型假設的結果與真實模型所提供的數據有出入,但模型的效果仍舊很好。

決策樹的一些缺點: 

  • 容易過擬合
  • 決策樹可能不穩定,數據中微小的變化都可能導致生成完全不同的樹。在集成中使用決策樹可以緩解這個問題。
  • 學習最優決策樹的問題通常是一個NP難問題。因此,實際的決策樹學習算法是基於啓發式算法。
  • 有些概念難以學習,決策樹不容易表達。例如XOR,奇偶或者複用器的問題

 

決策樹由結點(node) 和有向邊(directed edge) 組成。 結點有兩種類型: 內部結點(internal node) 和葉結點(leaf node)。內部結點表示一個特徵或屬性, 葉結點表示一個類。用決策樹分類, 從根結點開始, 對實例的某一特徵進行測試, 根據測試結果, 將實例分配到其子結點; 這時, 每一個子結點對應着該特徵的一個取值。 如此遞歸地對實例進行測試並分配, 直至達到葉結點。 最後將實例分到葉結點的類中。

例如 上圖就是一個決策過程生成的分類決策樹,能想象一個女生要相親,相當於通過年齡,長相,收入和是否是公務員對男孩分爲兩個類別:見和不見。假設這個女孩對男孩的要求是:30歲以下,長相中等以上並且是高收入者,或者中等以上收入的公務員,那麼這個圖表示女孩的決策邏輯。其中綠色結點表示判斷條件,橙色結點表示決策結果,箭頭表示在一個判斷條件在不同情況下的決策路徑,圖中紅色箭頭表示了上面例子中女孩的決策過程。

 

決策樹的構造

訓練過程中構建這棵決策樹的時候要怎麼做呢?就是一個一個特徵屬性依次比較過去然後建立分支嗎?不是的,我們需要挑選最有代表性的特徵。特徵選擇在於選取對訓練數據具有分類能力的特徵。 這樣可以提高決策樹學習的效率。在構建決策樹的過程中,最重要的就是怎麼選取合適的特徵來構建它。如果選取不合理,可能會造成產生的決策樹過於龐大,提升程序的複雜度,此外也會造成決策樹的泛化性能降低。生成決策樹的算法通常有ID3、C4.5、CART。

信息增益

什麼是信息的不確定性?就是信息熵。我們給出信息熵的定義:

在熵H(P)越大時,表示隨機變量的不確定性越大。而熵越大,信息就越混亂,P(x)機率就越小。 

在決策樹根節點的最初,我們先假設信息熵爲1,表示我們的一無所知。到葉節點時假設信息熵爲0,表示非常確信。那麼使用決策樹決策的過程就是我們不斷減少信息熵的過程,直到它降爲0。我們的目標是希望信息熵能下降得快一點。這就涉及到決策樹的構建了,我們該怎麼構建才能使得這棵決策樹的信息熵在判斷分支中下降得最快呢?這就是信息增益。

接下來給出信息增益的算法:

 

sklearn包中決策樹算法的使用----分類 

與其他分類器一樣,  DecisionTreeClassifier將兩個數組作爲輸入:一個數組X,用[n_samples, n_features]存放訓練樣本;Y數組用[n_samples]來存放訓練樣本的類標籤。

#導入sklearn包
>>> from sklearn import tree
>>> X = [[0, 0], [1, 1]]
>>> Y = [0, 1]
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, Y)

接下來,模型可以預測樣本的類別。

>>> clf.predict([[2., 2.]])
array([1])

或者,可以預測每個類的概率,這個概率是葉子中同類訓練樣本的比例.predict_proba返回的是一個 n 行 k 列的數組, 第 i 行 第 j 列上的數值是模型預測 第 i 個預測樣本爲某個標籤的概率,並且每一行的概率和爲1。

>>> clf.predict_proba([[2., 2.]])
array([[0., 1.]])

所以結果表示預測[2.,2.]的標籤是0的概率是0,是1的概率是1。

 DecisionTreeClassifier既支持二分類(其標籤爲[-1,1]),也支持多分類([0, …, K-1])。

利用Iris數據集,我們可以構建如下樹:

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> X, y = load_iris(return_X_y=True)
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, y)

一旦經過訓練,可以用plot_tree 函數繪製樹,也可以通過graphviz將樹可視化。首先要安裝一下這個包,如果用conda管理包可以用指令 conda install python-graphviz安裝,我是先用系統安裝了一下再用python安裝。指令:brew install graphviz  // pip3 install graphviz

>>> import graphviz 
>>> dot_data = tree.export_graphviz(clf, out_file=None) 
>>> graph = graphviz.Source(dot_data) 
>>> graph.render("iris")

這樣就已經生成一個pdf文件

 export_graphviz也支持各種美化圖形。加入各種參數

dot_data = tree.export_graphviz(clf, out_file=None, 
                        feature_names=iris.feature_names,  
                        class_names=iris.target_names,  
                        filled=True, rounded=True)

graph=graphviz.Source(dot_data)
graph.render("iris")

接下來的示例用的是iris數據集。Iris 鳶尾花數據集是一個經典數據集,在統計學習和機器學習領域都經常被用作示例。數據集內包含 3 類(Iris Setosa,Iris Versicolour,Iris Virginica)共 150 條記錄,每類各 50 個數據,每條記錄都有 4 項特徵:花萼長度、花萼寬度、花瓣長度、花瓣寬度,可以通過這4個特徵預測鳶尾花卉屬於哪一品種。

基於iris數據集繪製決策樹

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Parameters
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02

# Load data
iris = load_iris()

for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
                                [1, 2], [1, 3], [2, 3]]):
    # We only take the two corresponding features
    #從四列數據中選取兩個特徵進行訓練
    X = iris.data[:, pair]
    y = iris.target

    # Train
    clf = DecisionTreeClassifier().fit(X, y)

    # Plot the decision boundary
    plt.subplot(2, 3, pairidx + 1)
    #subplot直接指定劃分方式和位置進行繪圖,2行3列排列圖片
    
    #繪製決策邊界,選擇最大值,最小值
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1

    # numpy.meshgrid()——生成網格點座標矩陣。numpy.arange()分割數
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                         np.arange(y_min, y_max, plot_step))
    # tight_layout() 進行自動控制圖像佈局,通過參數pad, w_pad, h_pad設置佈局細節
    plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)

    
    # 按照第一個循環,把第一列花萼長度數據按h取等分,作爲行,然後複製多行,得到xx網格矩陣
    #把第二列的花萼寬度數據按h取等分,作爲列,複製多列,得到網格矩陣
    #np.c_是按列連接兩個矩陣,就是把兩矩陣左右相加,要求行數相等,類似於pandas中的merge()
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    #繪製等高線的,contour和contourf都是畫三維等高線圖的
    #不同點在於contour() 是繪製輪廓線,contourf()會填充輪廓。
    #matplotlib.cm是色彩映射函數。 
    cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)
       
    
    #橫縱座標label特徵名稱
    plt.xlabel(iris.feature_names[pair[0]])
    plt.ylabel(iris.feature_names[pair[1]])


    # Plot the training points 繪製每個類別的鳶尾花數據的散點圖
    for i, color in zip(range(n_classes), plot_colors):
        # 這裏的numpy.where()只有一個參數,返回條件爲True的索引
        #所以這裏會依次返回每種鳶尾花的樣本索引。
        idx = np.where(y == i)
        #取出樣本的第0列,第1列
        plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
                    cmap=plt.cm.RdYlBu, edgecolor='black', s=15)

plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
plt.axis("tight")

plt.figure()
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plot_tree(clf, filled=True)
plt.show()

得到兩個圖 

 

上述代碼中其中,iris.data數據如下

 此時X= iris.data[:,pair],第一個循環中取iris.data數據中的第0列和第1列即iris.data[: , [0,1] ]

代碼中numpy.where(condition[,x,y])

參數:

condition : 數組,bool值

如果爲True, 則產生 x, 否則產生  y.

x, y : array_like, 可選

x與y的shape要相同,當condition中的值是true時返回x對應位置的值,false是返回y的

返回值:

out : ndarray 或ndarray 原組

①如果參數有condition,x和y,它們三個參數的shape是相同的。那麼,當condition中的值是true時返回x對應位置的值,false是返回y的。

②如果參數只有condition的話,返回值是condition中元素值爲true的位置索引,切是以元組形式返回,元組的元素是ndarray數組,表示位置的索引

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章