一. 前言
XGBoost是由陳天奇大神設計的一套基於gbdt的可並行計算的機器學習工具。在kaggle、天池等大數據競賽有着廣泛的應用。通過閱讀論文和代碼,受益良多,並總結了包括公式推導、並行化設計和源碼剖析等一系列筆記。本篇主要梳理了論文中的公式推導,並添加了一些推導的細節和自己的想法。如有錯誤,還需指正。
二. 從Boosting模型開始
說道GBDT或者XGBoost,就不得不說起Boosting模型,Boosting 是一種將弱分類器轉化爲強分類器的方法,它的函數模型是具有疊加性的。具體可表示爲
D代表數據集,m維,數量爲n。其中每棵樹(每次迭代)都是一個樹模型,可表示爲
q(x)代表樣本x到的樹模型的葉子節點的映射關係。w是樹模型中用來擬合屬於各自葉子節點的樣本的預測值。第二輪開始,每輪訓練輸入爲上一輪預測與真實值的殘差,最後的葉子節點的結果也爲殘差的預測,最後所有輪加一起即爲所求。
三. 目標函數的定義
目標函數是XGBOOST的一個特點,爲了防止過擬合,XGBoost的目標函數由損失函數和複雜度組成。複雜度又由葉子數量和L2正則組成。
其中i是樣本id,k是樹id(輪數),由於loss和複雜度項都是凸函數,所以有最小值。
這個也很好理解,w是與真實值的殘差,將w的L2正則加在目標函數中,可以有效防止過擬合。(L2正則是很常用的規則化公式,正則的知識請自行補充)
同樣葉子數的線性項也加在了目標函數中,一定程度上限制葉子數量,防止過擬合。而傳統GDBT方法防止過擬合的手段無論是預剪枝還是後剪枝,都是額外進行交叉驗證的步驟。
四.推導出最優估計
目標函數確定了,接下來就是訓練的過程了。對於每次迭代過程,可以將一棵樹的訓練目標函數寫成形式如下:
輸入是t-1輪後預測的值,真實值,用來擬合殘差f(x)。對於這個式子,我們不知道loss的具體形式,所以無法對f(x)進行有效的最優估計。於是進行如下推導:
首先將目標函數泰勒二階展開近似
去掉常數項即爲:
正則項展開
首先做一下轉換:
公式轉化爲:
這是一個二次項形式,所以最後使得目標函數最小的w爲
帶入原方程最小值爲
那麼最後求得的w就是目標函數在一個樣本集合條件下的最優解。
爲什麼要推導呢?
1. 適應各種損失函數。
2. 正則項的加入對參數估計產生了影響,要具體找出影響是什麼。
如果只考慮平方損失的條件下,在沒有正則項的情況下參數的最優估計爲樣本均值。在沒有指定損失函數情況下,我們也很容易想到均值是給定樣本條件下的誤差最小的最優估計。但是損失函數換成絕對值損失,那麼最優估計就爲中位數。可見不同損失函數下,結果並不想我們想的那麼簡單。陳天奇大神在博客中也說,推導是爲了使得模型更具有一般性。
爲什麼用泰勒二階近似展開?
首先我們看一下GBDT的算法流程
由於gbdt只用到了一階信息,如果按照上文中推導,相當於loss只進行了一階泰勒展開。在沒有複雜度項的情況下,無法確定步長,所以只能用常數步長根據一階梯度方向去逼近。這就是牛頓下降法和梯度下降法的區別。由於二階展開用二次函數去逼近函數,所以可以利用二階信息確定更新步長,比只利用一階信息的gdbt用更少的迭代獲得更好的效果。感興趣可以自己證明牛頓梯度法的二階收斂性。
五 貪心算法求解
但是滿足給定樣本屬性條件下決策樹有很多,找到其中能使目標函數最優的決策樹是一個nphard問題,這裏使用了貪心算法的策略來近似求解。
假設存在一個遊標,先將分支下的樣本按照一個屬性進行排序,再滑動遊標按照這個屬性從小到大遍歷,分別將樣本分爲兩部分,將兩部分的樣本帶入上一節求得的目標函數最小值公式,再相加,與不分裂的樣本的目標函數做差,即爲分裂後的“收益”。貪心的策略就是找出收益最大的分裂“遊標”的位置。
到這裏,再來反推一下目標函數的設計思路,如果沒有加入複雜度,那麼這裏的損失函數即爲:
可以預想到,兩個參數估計一組樣本肯定比一個參數估計一組樣本要好,所以樹會一直分裂下去。按照傳統cart樹,會有額外的根據交叉驗證的預剪枝或者後剪枝。但是xgboost通過巧妙的設計目標函數,先是在分母上加一個λ,來降低分支的收益“靈敏度”,這個“靈敏度”可以通過修改此參數來控制。當收益小於一個閾值則剪枝,從而達到防止過擬合的目的。(類似的在分母加上常數項的處理,在樸素貝葉斯單個屬性條件概率的計算中也有用到)
另外的一個參數γ,從最後每次分割後的收益函數可以看到,這個參數的“物理意義”就是每分裂一次,減去一個視爲懲罰的常數。
六.簡單的改進方向
這裏的貪心算法,是爲了提高效率而採取的近似算法。由於是近似算法,所以有一定的改進空間,具體也要結合實際的樣本情況。
在參加天池大數據競賽機場賽的時候,由於有數據在時間上有很強的週期波動性,我自己實現了一個簡單的gdbt,只對時間屬性上的分支,採用了在每一次分割用兩個遊標,分割成三個子節點的方法,實際效果,要比相同特徵情況下的xgboost準確率高。