集成樹之GBDT算法

目錄:

  1. 導言
  2. GBDT算法
    2.1 GBDT迴歸樹
    2.2 GBDT分類樹
    2.2.1 二分類
    2.2.2 多分類
  3. Bosting提升樹
    3.1 爲什麼Boosting能夠降低偏差

1 導言

常見的集成學習算法一般可以分爲兩大類:Boosting和Bagging。Boosting類的核心思想在於訓練多個弱分類器,最終的結果是這n個弱分類器的和,每一個弱分類器的目標是學習前m個弱分類器的和與樣本label的殘差,因此Boosting這類算法需要弱分類器足夠簡單,並且是低方差、高偏差的模型,因爲訓練的過程是通過降低偏差不斷提高精度的
Boosting的典型代表便是GBDT,在GBDT的基礎上還延伸出XgBoost、LightGBM、CatBoost等算法。Bagging的代表是隨機森林算法,這類算法的目標是學習多個弱分類器,最終的結果是這n個弱分類器的加權平均,其中每個弱分類器的目標都是直接對樣本進行學習,他們之間沒有強相關性。

2 GBDT算法

GBDT是一種由CART迴歸樹構成的集成學習算法,可用做分類、迴歸任務。關於決策樹的介紹,可見另一篇文章:深入理解決策樹
GBDT算法的核心思想:通過多輪迭代學習,每次迭代產生一個弱分類器,每個分類器的目標是對上一輪分類器的殘差,也就是彌補前n個分類器的不足。

2.1 GBDT迴歸樹

模型可表示爲:Fm(x)=i=1mT(x;θm)F_m(x)=\sum_{i=1}^mT(x;{\theta}_m)
假設模型一共mm輪,每輪產生一個弱學習器T(x;θm)T(x;{\theta}_m),那麼弱學習器的損失函數可表示爲
θm^=argminθmLiNL(yi,Fm)\hat{{\theta}_m}=\arg\underset{{\theta}_m}{\min}L\sum_i^NL(y_i,F_m)
其中Fm=Fm1(xi)+T(xi;θm)F_m=F_{m-1}(x_i)+T(x_i;{\theta}_m)Fm1(x)F_{m-1}(x)爲當前模型,GBDT通過經驗風險極小化學習下一個弱學習器的參數。
損失函數LL可選擇平方損失函數,我們讓損失函數沿殘差梯度方向下降。
在訓練每個弱學習器時,利用損失函數的負梯度當作該弱學習器的目標去學習,也就是用殘差去擬合一棵迴歸樹。負梯度可表示爲
rti=L(yi,f(xi))f(xi)f(x)=ft1(x)r_{ti}=-\frac{{\partial}L(y_i,f(x_i))}{{\partial}f(x_i)}_{f(x)=f_{t-1}(x)}
利用(xi,rti),(i=1,2,...,m)(x_i,r_{ti}),(i=1,2,...,m)去擬合一棵CART迴歸樹。
當然,算法工程師的面試是少不了手推滴,下面來手推一下:
算法流程

  1. 初始化弱學習器:
    KaTeX parse error: Got function '\min' with no arguments as argument to '\underset' at position 29: …nderset{\gamma}\̲m̲i̲n̲\sum_{i=1}^NL(y…
  2. 迭代m=1,2,...,Mm=1,2,...,M,對每個樣本,計算其負梯度:
    γti=L(yi,f(xi))f(xi)f(x)=ft1(x)\gamma_{ti}=-\frac{{\partial}L(y_i,f(x_i))}{{\partial}f(x_i)}_{f(x)=f_{t-1}(x)}
  3. 將上式中的γim\gamma_{im}作爲新的樣本label,並將數據(xi,γim),(i=1,2,...,N)(x_i,\gamma_{im}),(i=1,2,...,N)作爲下棵樹的訓練數據,得到新的迴歸樹fm(x)f_m(x),其對應的葉子結點區域爲Rjm,j=1,2,...,JR_{jm},j=1,2,...,J,其中JJ爲迴歸樹的葉子結點個數。
  4. 對於葉子區域j=1,2,...,Jj=1,2,...,J,計算最佳擬合值:
    KaTeX parse error: Got function '\min' with no arguments as argument to '\underset' at position 34: …nderset{\gamma}\̲m̲i̲n̲\sum_{x_i\in{R_…
  5. 更新強學習器:
    fm(x)=fm1(x)+j=1JγjmI(xRjm)f_m(x)=f_{m-1}(x)+\sum_{j=1}^J\gamma_{jm}I(x\in{R_{jm}})
2.1.1 shrinkge

如果在提升樹算法上使用shrinkge能夠在一定程度上防止過擬合,其主要思想爲:

每次迭代走一小步逼近label,通過多次迭代的方式,要比每次邁一大步,通過較少次迭代逼近label的方式更容易避免過擬合。簡單來說就是它不完全信任每一棵殘差樹,它認爲每棵樹只學習到了真理的一部分,所以在累加計算的時候只給它較小的權重,通過多學習出來幾棵樹來彌補不足。

  1. 沒有shrinkge時:
    y(i+1)=(y1yi)y_(i+1)=殘差(y_1\sim{y_i})
    (y1yi)=yy(1yi)殘差(y_1\sim{y_i})=y_真-y(1\sim{y_i})
    y(1yi)=sum(y1,...,yi)y(1\sim{y_i})=sum(y1,...,y_i)
  2. 有shrinkge時:
    y(1yi)=y(1yi1)+stepyiy(1\sim{y_i})=y(1\sim{y_{i-1}})+step*y_i
    step一般取0.0010.010.001~0.01,相當於爲每棵樹設置一個權重,導致各個樹的殘差是漸變的,而不是陡變的。

2.2 GBDT分類樹

2.2.1二元分類

假設該二分類爲0,1分類,假設訓練樣本中label爲1的比例爲P1P1
算法流程

  1. F0(x)=h(x)=logP11P1F_0(x)=h(x)=log\frac{P_1}{1-P_1}
  2. For m=1,2,…,M:
    a.計算gi=yi^yig_i=\hat{y_i}-y_i,用{(xi,gi)}i=1n\left\{ (x_i,-g_i) \right\}_{i=1}^n訓練一棵迴歸樹tm(x)t_m(x),其中yi^=11+eFm(x)\hat{y_i}=\frac{1}{1+e^{-F_m(x)}}
    b.通過一緯最小化損失函數找到最優權重:
    ρm=argminρiloss(xi,yiFm1(x)+ρtm(x))\rho_m= \mathop{\arg\min}_{\rho}\sum_i{loss(x_i,y_i\big|F_{m-1}(x)+\rho{t_m(x)})}
    c.考慮shrinkge,可得
    Fm(x)=Fm1(x)+αρmtm(x)F_m(x)=F_{m-1}(x)+\alpha\rho_mt_m(x)
    其中,α\alpha爲學習率。
2.2.2 多分類

對於多分類情況,則要考慮以下softmax模型:
P(y=1x)=eF1(x)i=1keFi(x)P(y=1\big|x)=\frac{e^{F_1(x)}}{\sum_{i=1}^ke^{F_i(x)}}
P(y=2x)=eF2(x)i=1keFi(x)P(y=2\big|x)=\frac{e^{F_2(x)}}{\sum_{i=1}^ke^{F_i(x)}}
......
P(y=kx)=eFk(x)i=1keFi(x)P(y=k\big|x)=\frac{e^{F_k(x)}}{\sum_{i=1}^ke^{F_i(x)}}
F1,F2,...,FkF_1,F_2,...,F_kkk個不同的tree ensemble,每一輪的訓練實際上是訓練了kk棵樹去擬合softmax的每一個分支的負梯度。
softmax模型的單樣本損失函數可表示爲:
loss=i=1kyilogP(yix)=i=1kyilogeFi(x)j=1keFj(x)loss=-\sum_{i=1}^ky_ilogP(y_i\big|x)=-\sum_{i=1}^ky_i*log\frac{e^{F_i(x)}}{\sum_{j=1}^ke^{F_j(x)}}
這裏的yiy_i是樣本label在kk個類別上做one-hot編碼之後的取值,只有一維爲1,其餘爲0。
LossFq=yqeFq(z)j=1keFj(x)=yqyq^-\frac{\partial Loss}{\partial F_q}=y_q-\frac{e^{F_q(z)}}{\sum_{j=1}^ke^{F_j(x)}}=y_q-\hat{y_q}
所以這kk棵樹同樣是擬合了樣本的真實標籤與預測概率之差,與二分類樹在本質上是一致的。

3 爲什麼Boosting能夠降低偏差

對於Boosting提升樹算法來說,偏差可以表示爲:
E(F)=γimE(fi)E(F)=\gamma*\sum_i^mE(f_i)
方差可以表示爲:
Var(F)=m2γ2σ2ρ+mγ2σ2(1ρ)Var(F)=m^2\gamma^2\sigma^2\rho+m\gamma^2\sigma^2(1-\rho)
=m2γ2σ21+mγ2σ20=m^2\gamma^2\sigma^2*1+m\gamma^2\sigma^2*0
=m2γ2σ2=m^2\gamma^2\sigma^2
所以如果基學習器如果不是弱模型的話,方差會較大,這將會導致整個模型的方差較大,嚴重過擬合。
因爲基學習器是弱模型,所以精度不高,隨着基學習器數量mm的增大,整個模型的準確度會提高,更接近真實值,但不會無限逼近於1,因爲隨着基學習器的數量mm提高,整個模型的方差也會變大,抗過擬合能力降低,在一定程度上會導致準確度下降。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章