1.決策樹原理
決策樹算法重點就在於“決策”和“樹”這兩個概念,顧名思義決策樹是基於樹結構來進行決策的,這也恰恰是人們在遇到問題時進行問題梳理的一種很自然的處理機制。
決策樹的目標是建立分類和迴歸模型,核心目標是決策樹的生長和決策樹的修剪。
對於決策樹的生長算法有:ID3,C5.0,CART,CHAID,QUEST等;
對於決策樹的修剪方法有:預剪枝,後剪枝。
2.決策樹優缺點
優點:
- 不需要預處理,不需要提前歸一化,處理缺失值;
- 既可以處理離散值也可以處理連續值。很多算法只是專注於離散值或者連續值;
- 簡單直觀,生成的決策樹很直觀;
- 使用決策樹預測的代價是O(log2m)O(log2m)。 m爲樣本數;
- 可以處理多維度輸出的分類問題;
- 相比於神經網絡之類的黑盒分類模型,決策樹在邏輯上可以得到很好的解釋;
- 可以交叉驗證的剪枝來選擇模型,從而提高泛化能力;
- 對於異常點的容錯能力好,健壯性高。
缺點: - 決策樹算法非常容易過擬合,導致泛化能力不強。可以通過設置節點最少樣本數量和限制決策樹深度來改進;
- 決策樹會因爲樣本發生一點點的改動,就會導致樹結構的劇烈改變。這個可以通過集成學習之類的方法解決;
- 尋找最優的決策樹是一個NP難的問題,我們一般是通過啓發式方法,容易陷入局部最優。可以通過集成學習之類的方法來改善;
- 有些比較複雜的關係,決策樹很難學習,比如異或。這個就沒有辦法了,一般這種關係可以換神經網絡分類方法來解決;
- 如果某些特徵的樣本比例過大,生成決策樹容易偏向於這些特徵。這個可以通過調節樣本權重來改善。
3.CART算法
縱使決策樹的生成算法有很多,但是scikit-learn決策樹算法類庫內部實現是使用了調優過的CART樹算法,既可以做分類,又可以做迴歸。分類迴歸樹(Classification And Regression Tree, CART)是由美國斯坦福大學和加州大學伯克利分校的佈雷曼(Breiman) 等人於1984年提出的,同年他們出版了相關專著Classification and Regression Trees。
CART算法也有樹的生成和剪枝兩部分,對於樹的生成採用的標準主要是:基尼係數(分類),方差(迴歸);對於樹的剪枝採用的標準主要是是:MCCP算法(最小代價複雜性修剪法)。對於這兩部分的理論介紹這裏不再給出,隨便找一本介紹CART樹算法的書都有相應的介紹,例如:鏈接1,鏈接2……
4.CART算法實現
scikit-learn決策樹算法類庫中,分類決策樹的類對應的是DecisionTreeClassifier,而回歸決策樹的類對應的是DecisionTreeRegressor。
對於這兩者來說,參數定義幾乎完全相同,但是意義不全相同。下面就對DecisionTreeClassifier和DecisionTreeRegressor的重要參數做一個總結,重點比較兩者參數使用的不同點和調參的注意點。
DecisionTreeClassifier
DecisionTreeClassifier(criterion=’gini’, splitter=’best’, max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features=None, random_state=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None,
presort=False)
DecisionTreeRegressor
DecisionTreeRegressor(criterion=’mse’, splitter=’best’, max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features=None, random_state=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None, presort=False)
5. 應用實例–泰坦尼克號數據集
5.1 數據集獲取
鏈接:https://pan.baidu.com/s/13Qd_qoR22B4VRvUIGm296w 提取碼:x4ic
5.2 數據描述
特徵 | 描述 |
---|---|
Survived | 0代表死亡,1代表存活 |
Pclass | 乘客所持票類,有三種值(1,2,3) |
Name | 乘客姓名 |
Sex | 乘客性別 |
Age | 乘客年齡(有缺失) |
SibSp | 乘客兄弟姐妹/配偶的個數(整數值) |
Parch | 乘客父母/孩子的個數(整數值) |
Ticket | 票號(字符串) |
Fare | 乘客所持票的價格(浮點數,0-500不等) |
Cabin | 乘客所在船艙(有缺失) |
Embark | 乘客登船港口S、C、Q(有缺失) |
5.3 代碼實例
1.導入所需要的庫
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
2.導入數據集,探索數據
data = pd.read_csv("E:/data/titianic.csv",index_col = 0)
data.head()
data.info()
3.對數據集進行預處理
#刪除缺失值過多的列,和觀察判斷來說和預測的y沒有關係的列
data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
#處理缺失值,對缺失值較多的列進行填補,有一些特徵只確實一兩個值,可以採取直接刪除記錄的方法
data["Age"] = data["Age"].fillna(data["Age"].mean())
data = data.dropna()
#將分類變量轉換爲數值型變量
#將二分類變量轉換爲數值型變量
#astype能夠將一個pandas對象轉換爲某種類型,和apply(int(x))不同,astype可以將文本類轉換爲數字,用這個方式可以很便捷地將二分類特徵轉換爲0~1
data["Sex"] = (data["Sex"]== "male").astype("int")
#將三分類變量轉換爲數值型變量
labels = data["Embarked"].unique().tolist()
data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))
#查看處理後的數據集
data.head()
4.提取X和Y,劃分數據集
X = data.iloc[:,data.columns != "Survived"]
y = data.iloc[:,data.columns == "Survived"]
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3)
#修正測試集和訓練集的索引
for i in [Xtrain, Xtest, Ytrain, Ytest]:
i.index = range(i.shape[0])
#查看分好的訓練集和測試集
Xtrain.head()
5.首次嘗試,粗略查看結果
clf = DecisionTreeClassifier(random_state=42)
clf = clf.fit(Xtrain, Ytrain)
score_ = clf.score(Xtest, Ytest)
print(score_)
#嘗試一下10折交叉驗證的平均分和方差
score = cross_val_score(clf,X,y,cv=10)
print(score)
print(score.mean())
print(score.std())
6.在不同max_depth下觀察模型的擬合效果
#當然這裏你也可以嘗試其他參數,不一定試驗max_depth
tr = []
te = []
for i in range(10):
clf = DecisionTreeClassifier(random_state=42
,max_depth=i+1
,criterion="entropy"
)
clf = clf.fit(Xtrain, Ytrain)
score_tr = clf.score(Xtrain,Ytrain)
score_te = clf.score(Xtest, Ytest)
tr.append(score_tr)
te.append(score_te)
print("train:{}".format(max(tr)))
print("test:{}".format(max(te)))
plt.plot(range(1,11),tr,color="red",label="train")
plt.plot(range(1,11),te,color="blue",label="test")
plt.xticks(range(1,11))
plt.legend()
plt.show()
7.用網格搜索調整參數
parameters = {'splitter':('best','random')
,'criterion':("gini","entropy")
,"max_depth":np.arange(1,10)
,'min_samples_leaf':np.arange(1,50,5)
,'min_impurity_decrease':np.linspace(0,0.5,20)
}
clf = DecisionTreeClassifier(random_state=42)
GS = GridSearchCV(clf, parameters, cv=10)
GS.fit(Xtrain,Ytrain)
GS.best_params_
GS.best_score_