決策樹
什麼是決策樹
比方說我們在招聘一個機器學習算法工程師的時候,會依照這樣的流程進行逐層的評選,從而達到一個樹形結構的決策過程。而在這棵樹中,它的深度爲3.最多通過3次判斷,就能將我們的數據進行一個相應的分類。我們在這裏每一個節點都可以用yes或者no來回答的問題,實際上我們真實的數據很多內容都是一個具體的數值。對於這些具體的數值,決策樹是怎麼表徵的呢?我們先使用scikit-learn封裝的決策樹算法進行一下具體的分類。然後通過分類的結果再深入的認識一下決策樹。這裏我依然先加載鳶尾花數據集。
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets if __name__ == "__main__": iris = datasets.load_iris() # 保留鳶尾花數據集的後兩個特徵 X = iris.data[:, 2:] y = iris.target plt.scatter(X[y == 0, 0], X[y == 0, 1]) plt.scatter(X[y == 1, 0], X[y == 1, 1]) plt.scatter(X[y == 2, 0], X[y == 2, 1]) plt.show()
運行結果
現在我們使用scikit-learn中的決策樹來進行分類
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.tree import DecisionTreeClassifier from matplotlib.colors import ListedColormap if __name__ == "__main__": iris = datasets.load_iris() # 保留鳶尾花數據集的後兩個特徵 X = iris.data[:, 2:] y = iris.target # plt.scatter(X[y == 0, 0], X[y == 0, 1]) # plt.scatter(X[y == 1, 0], X[y == 1, 1]) # plt.scatter(X[y == 2, 0], X[y == 2, 1]) # plt.show() dt_clf = DecisionTreeClassifier(max_depth=2, criterion='entropy') dt_clf.fit(X, y) def plot_decision_boundary(model, axis): # 繪製不規則決策邊界 x0, x1 = np.meshgrid( np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1), np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1) ) X_new = np.c_[x0.ravel(), x1.ravel()] y_predict = model.predict(X_new) zz = y_predict.reshape(x0.shape) custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9']) plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap) plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3]) plt.scatter(X[y == 0, 0], X[y == 0, 1]) plt.scatter(X[y == 1, 0], X[y == 1, 1]) plt.scatter(X[y == 2, 0], X[y == 2, 1]) plt.show()
運行結果
這就是決策樹得到的決策邊界。通過這個圖,我們來畫一個決策樹,看看它是如何逐層決策的
這就是決策樹在面對屬性是一種數值特徵的時候是怎樣處理的。在這裏我們可以看到,在每一個節點上,它選擇某一個維度,以及和這個維度相應的閾值,比如說在根節點的時候選的是x這個維度和2.4這個閾值。看我們的數據是大於等於2.4還是小於2.4分成兩支。在右分支的子節點,選擇了y這個維度和1.8這個閾值,看數據點到了這個節點的話是小於1.8還是大於等於1.8來進行分類。
首先決策樹是一個非參數學習算法。其次決策樹可以解決分類問題,而且可以天然的解決多分類問題。不像邏輯迴歸和SVM需要使用OvR或者OvO才能解決多分類問題。同時決策樹也可以解決迴歸問題,我們可以用最終落在這個葉子節點中的所有的樣本數據的平均值來當作迴歸問題的預測值。決策樹算法也是具有非常好的可解釋性。比如說我們在給用戶的信用進行評級,比如說他的信用卡超過3次拖延支付,並且他的駕照平均每年都會被扣去8分,滿足這樣的條件,他的信用評級就會評爲C級。我們可以非常容易的描述出來把樣本數據分成某一個類的依據。我們現在的問題就是每個節點到底在哪個維度做劃分?我們的數據複雜起來的話可能有成百上千個維度。如果我們選好了一個維度,那麼我們要在某個維度的哪個值上做劃分?這些都是構建決策樹的關鍵問題。