Decision Trees: An Overview

1. Concepts

Algorithm | Impurity measure       | Data types            | Tasks                        | Branching
ID3       | Information gain       | Discrete              | Classification               | Multi-way split at each node
C4.5      | Information gain ratio | Discrete / continuous | Classification               | Multi-way split at each node
CART      | Gini index             | Discrete / continuous | Classification / regression  | Binary split at each node

By the definition from information theory, the disorder of information is measured by entropy. Suppose the sample data $X$ contains $N$ classes; then

$$H(X) = -\sum_{j=1}^{N} p_j \log p_j$$

Information gain measures how much the impurity decreases when the data at a node is split, i.e. the difference between the entropy before and after the split:

$$\mathrm{info\_gain} = H(X) - \sum_i \frac{|X_i|}{|X|} H(X_i)$$

The drawback of information gain is obvious: when a feature takes many distinct values (a name, for example), each value may correspond to a single record, so splitting on that feature yields a very large information gain, yet the trained model generalizes extremely poorly. For this reason the information-gain-based ID3 algorithm is only suitable for discrete data with few distinct values. To handle such cases, the information gain ratio adjusts the information gain as follows:

$$\mathrm{info\_gain\_ratio} = \frac{\mathrm{info\_gain}}{H(A)}$$

where $H(A)$ is the entropy of the values of attribute $A$. The Gini index is another impurity measure, an alternative to entropy:

$$Gini(X) = 1 - \sum_{j=1}^{N} p_j^2$$

In the CART algorithm we want to maximize the Gini gain of the split:

$$Gini\_gain = Gini(X) - \sum_i \frac{|X_i|}{|X|} Gini(X_i)$$

There are many other impurity measures in the literature; they are not covered here.
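To make the definitions concrete, here is a minimal NumPy sketch of these impurity measures (the helper names are our own, and log base 2 is assumed):

import numpy as np

def entropy(labels):
    """H(X) = -sum_j p_j * log2(p_j), over the class frequencies in `labels`."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def gini(labels):
    """Gini(X) = 1 - sum_j p_j^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def information_gain(labels, attr_values):
    """info_gain = H(X) - sum_i |X_i|/|X| * H(X_i), splitting on each distinct attribute value."""
    labels, attr_values = np.asarray(labels), np.asarray(attr_values)
    weighted = sum(
        np.mean(attr_values == v) * entropy(labels[attr_values == v])
        for v in np.unique(attr_values)
    )
    return entropy(labels) - weighted

def information_gain_ratio(labels, attr_values):
    """info_gain_ratio = info_gain / H(A), where H(A) is the entropy of the attribute's own values."""
    h_a = entropy(np.asarray(attr_values))
    return information_gain(labels, attr_values) / h_a if h_a > 0 else 0.0

# Tiny usage example: a two-class toy node split on a binary attribute.
y = ["yes", "yes", "no", "no", "yes", "no"]
a = ["sunny", "sunny", "rainy", "rainy", "sunny", "sunny"]
print(entropy(y), gini(y), information_gain(y, a), information_gain_ratio(y, a))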

2. Algorithms

ID3

This is the simplest of the decision-tree algorithms discussed here. The pseudocode below includes a pre-pruning step (stop when the information gain is too small, or when performance on a validation set can no longer be improved); the same step also applies to C4.5 and CART. The drawback of using information gain as ID3's impurity measure was explained above.

Algorithm ID3(Node):
Input: Object Node containing sample data.
Output: N/A.
if the depth exceeds the specified maximum depth then label Node with the majority class y and terminate the branch
if the samples in Node all belong to the same class y then label Node with y and terminate the branch
if there is no unused attribute remaining then label Node with the class y that has the most samples and terminate the branch
for each unused attribute A do
  calculate the information gain
select the attribute A* that maximizes the information gain as the branching attribute at Node
call pruning()    # pre-pruning: terminate the branch in advance, e.g. when the information gain is too small
segment the samples of Node into M fractions based on their values of A*
for each segment Di do
  initialize child node Node_i and feed Di to it
  recursively call ID3(Node_i)
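A compact Python sketch of the same recursion, reusing the information_gain helper from the earlier sketch; pre-pruning is reduced here to a minimum-gain threshold, and max_depth / min_gain are illustrative parameters of our own:

from collections import Counter
import numpy as np

def id3(X, y, attributes, depth=0, max_depth=5, min_gain=1e-3):
    """Grow an ID3 tree over discrete attributes.

    X: list of dicts (attribute name -> value); y: list of class labels.
    Returns a class label (leaf) or a nested dict {attribute: {value: subtree}}.
    """
    majority = Counter(y).most_common(1)[0][0]
    # Stop: maximum depth reached, node is pure, or no attribute remains.
    if depth >= max_depth or len(set(y)) == 1 or not attributes:
        return majority
    # Choose the attribute with the largest information gain.
    y_arr = np.asarray(y)
    gains = {a: information_gain(y_arr, [row[a] for row in X]) for a in attributes}
    best = max(gains, key=gains.get)
    if gains[best] < min_gain:                 # pre-pruning
        return majority
    tree = {best: {}}
    for v in set(row[best] for row in X):      # one branch per attribute value
        idx = [i for i, row in enumerate(X) if row[best] == v]
        tree[best][v] = id3([X[i] for i in idx], [y[i] for i in idx],
                            [a for a in attributes if a != best],
                            depth + 1, max_depth, min_gain)
    return tree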

C4.5

To avoid ID3's bias toward attributes with many distinct values, C4.5 uses the information gain ratio as the impurity measure. In addition, C4.5 adds a binary-split procedure for continuous variables: the feature values are first sorted, a list of candidate thresholds is generated from the means of adjacent values, and the best split point is then chosen among them.
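A sketch of that threshold search, scoring each candidate with the information_gain_ratio helper from the first sketch (the function name best_threshold is our own):

import numpy as np

def best_threshold(values, labels):
    """Return (threshold, gain_ratio) for the best binary split of a continuous attribute."""
    values, labels = np.asarray(values, dtype=float), np.asarray(labels)
    distinct = np.unique(values)                        # sorted distinct values
    candidates = (distinct[:-1] + distinct[1:]) / 2.0   # midpoints of adjacent values
    best_t, best_score = None, -np.inf
    for t in candidates:
        partition = np.where(values < t, "left", "right")   # binary pseudo-attribute
        score = information_gain_ratio(labels, partition)
        if score > best_score:
            best_t, best_score = t, score
    return best_t, best_score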

Algorithm C4.5(Node):
Input: Object Node containing sample data.
Output: N/A.
if the depth exceeds the specified maximum depth then label Node with the majority class y and terminate the branch
if the samples in Node all belong to the same class y then label Node with y and terminate the branch
if there is no unused attribute remaining then label Node with the class y that has the most samples and terminate the branch
for each unused attribute A do
  if A is discrete then
    calculate the information gain ratio
  else
    select the optimal threshold value that maximizes the information gain ratio
select the attribute A* that maximizes the information gain ratio as the branching attribute at Node
call pruning()    # pre-pruning: terminate the branch in advance, e.g. when the information gain ratio is too small
segment the samples of Node into M fractions based on the branching principle
for each segment Di do
  initialize child node Node_i and feed Di to it
  recursively call C4.5(Node_i)

CART

Compared with C4.5, CART also applies binary splits to discrete variables, which constrains the tree to be a binary tree, and it adds a procedure for handling regression tasks.
For classification, CART selects the best split point by maximizing the Gini gain (the search is similar to that of C4.5):

$$\rho^* = \arg\max_\rho \Big[ Gini(X) - \sum_i \frac{|X_i|}{|X|} Gini(X_i) \Big]$$

For regression, CART uses a least-squares criterion:

$$\rho^* = \arg\min_\rho \Big[ \sum_{x_i < \rho} \big(y_i - \bar{y}_{x_i < \rho}\big)^2 + \sum_{x_i \ge \rho} \big(y_i - \bar{y}_{x_i \ge \rho}\big)^2 \Big]$$
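A minimal sketch of the regression criterion above, assuming NumPy (the function name best_regression_split is our own):

import numpy as np

def best_regression_split(x, y):
    """Find the threshold rho that minimizes the total squared error of the two sides."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    distinct = np.unique(x)
    candidates = (distinct[:-1] + distinct[1:]) / 2.0   # midpoints of adjacent values
    best_rho, best_sse = None, np.inf
    for rho in candidates:
        left, right = y[x < rho], y[x >= rho]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if sse < best_sse:
            best_rho, best_sse = rho, sse
    return best_rho, best_sse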

Algorithm CART(Node):
Input: Object Node containing sample data.
Output: N/A.
if the depth exceeds the specified maximum depth then terminate the branch
if the samples in Node all belong to the same class, or their target values span a range smaller than the required threshold, then terminate the branch
if there is no unused attribute remaining then terminate the branch
for each unused attribute A do
  if A is discrete then
    select the optimal value that maximizes the Gini gain or minimizes the sum of squared errors
  else
    select the optimal threshold value that maximizes the Gini gain or minimizes the sum of squared errors
select the attribute A* that maximizes the Gini gain or minimizes the sum of squared errors as the branching attribute at Node
call pruning()    # pre-pruning: terminate the branch in advance, e.g. when the Gini gain is too small
segment the samples of Node into two fractions based on the branching principle
for each segment Di (i = 1, 2) do
  initialize child node Node_i and feed Di to it
  recursively call CART(Node_i)
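For discrete attributes, CART must choose a binary partition of the attribute's values. The sketch below uses a simple one-value-versus-the-rest scheme scored by Gini gain, which is a simplification rather than a full subset search; it reuses the gini helper from the first sketch:

import numpy as np

def best_binary_split_discrete(values, labels):
    """Pick the attribute value whose one-vs-rest binary split maximizes the Gini gain."""
    values, labels = np.asarray(values), np.asarray(labels)
    parent = gini(labels)
    best_v, best_gain = None, -np.inf
    for v in np.unique(values):
        left, right = labels[values == v], labels[values != v]
        if len(left) == 0 or len(right) == 0:
            continue
        weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        gain = parent - weighted
        if gain > best_gain:
            best_v, best_gain = v, gain
    return best_v, best_gain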

3. Pruning

To prevent a decision tree from overfitting, two strategies are commonly used: pre-pruning and post-pruning. Pre-pruning sets early-stopping conditions that are applied while the branches are being grown, i.e. the pruning() step in the pseudocode above; typical conditions include "the impurity decrease is below a threshold" and "performance on a validation set can no longer be improved". Post-pruning trims the tree after it has been fully grown, and is usually done in one of two ways: use a validation set to find nodes that do not improve accuracy, or, following the idea of regularization, define a new loss function that combines sample impurity with model complexity. Taking CART classification as an example, the loss function in the second approach takes the form

$$L = \sum_i \frac{|X_i|}{|X|} Gini(X_i) + \alpha |N|$$

where $\alpha$ is a penalty factor (the larger its value, the heavier the penalty on model complexity) and $|N|$ is the number of child nodes below the node in question. If the loss before pruning is greater than the loss after pruning, the node is pruned.
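As a worked sketch of the second post-pruning criterion: compute the penalized loss of a subtree from its leaves, then compare it with the loss of collapsing the subtree into a single leaf. The tree representation and helper names are our own, |N| is taken here as the number of leaves under the node, and the gini helper from the first sketch is assumed:

def leaves(node):
    """A node is a leaf {'labels': [...]} or an internal node {'children': [...]}."""
    return [node] if 'children' not in node else [l for c in node['children'] for l in leaves(c)]

def penalized_loss(node, alpha):
    """L = sum_i |X_i|/|X| * Gini(X_i) + alpha * |N|, summing over the subtree's leaves."""
    lvs = leaves(node)
    total = sum(len(l['labels']) for l in lvs)
    impurity = sum(len(l['labels']) / total * gini(l['labels']) for l in lvs)
    return impurity + alpha * len(lvs)

def should_prune(node, alpha=0.1):
    """Prune if collapsing the subtree into one leaf lowers the penalized loss."""
    merged = {'labels': [lab for l in leaves(node) for lab in l['labels']]}
    return penalized_loss(merged, alpha) < penalized_loss(node, alpha)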
