集成學習:lightGBM(一)

日萌社

人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度學習實戰(不定時更新)


集成學習:Bagging、隨機森林、Boosting、GBDT

集成學習:XGBoost

集成學習:lightGBM(一)

集成學習:lightGBM(二)


5.5 lightGBM

1 寫在介紹lightGBM之前

1.1 lightGBM演進過程

1.2 AdaBoost算法

AdaBoost是一種提升樹的方法,和三個臭皮匠,賽過諸葛亮的道理一樣。

AdaBoost兩個問題:

  • (1) 如何改變訓練數據的權重或概率分佈
    • 提高前一輪被弱分類器錯誤分類的樣本的權重,降低前一輪被分對的權重
  • (2) 如何將弱分類器組合成一個強分類器,亦即,每個分類器,前面的權重如何設置
    • 採取”多數表決”的方法.加大分類錯誤率小的弱分類器的權重,使其作用較大,而減小分類錯誤率大的弱分類器的權重,使其在表決中起較小的作用。

1.3 GBDT算法以及優缺點

GBDT和AdaBosst很類似,但是又有所不同。

  • GBDT和其它Boosting算法一樣,通過將表現一般的幾個模型(通常是深度固定的決策樹)組合在一起來集成一個表現較好的模型。
  • AdaBoost是通過提升錯分數據點的權重來定位模型的不足, Gradient Boosting通過負梯度來識別問題,通過計算負梯度來改進模型,即通過反覆地選擇一個指向負梯度方向的函數,該算法可被看做在函數空間裏對目標函數進行優化。

因此可以說 。

缺點:

GBDT ->預排序方法(pre-sorted)

  • (1) 空間消耗大
    • 這樣的算法需要保存數據的特徵值,還保存了特徵排序的結果(例如排序後的索引,爲了後續快速的計算分割點),這裏需要消耗訓練數據兩倍的內存
  • (2) 時間上也有較大的開銷。
    • 在遍歷每一個分割點的時候,都需要進行分裂增益的計算,消耗的代價大。
  • (3) 對內存(cache)優化不友好。
    • 在預排序後,特徵對梯度的訪問是一種隨機訪問,並且不同的特徵訪問的順序不一樣,無法對cache進行優化。
    • 同時,在每一層長樹的時候,需要隨機訪問一個行索引到葉子索引的數組,並且不同特徵訪問的順序也不一樣,也會造成較大的cache miss。

1.4 啓發

常用的機器學習算法,例如神經網絡等算法,都可以以mini-batch的方式訓練,訓練數據的大小不會受到內存限制。

而GBDT在每一次迭代的時候,都需要遍歷整個訓練數據多次。

如果把整個訓練數據裝進內存則會限制訓練數據的大小;如果不裝進內存,反覆地讀寫訓練數據又會消耗非常大的時間。

尤其面對工業級海量的數據,普通的GBDT算法是不能滿足其需求的。

LightGBM提出的主要原因就是爲了解決GBDT在海量數據遇到的問題,讓GBDT可以更好更快地用於工業實踐。

2 什麼是lightGBM

lightGBM是2017年1月,微軟在GItHub上開源的一個新的梯度提升框架。

github介紹鏈接

在開源之後,就被別人冠以“速度驚人”、“支持分佈式”、“代碼清晰易懂”、“佔用內存小”等屬性。

LightGBM主打的高效並行訓練讓其性能超越現有其他boosting工具。在Higgs數據集上的試驗表明,LightGBM比XGBoost快將近10倍,內存佔用率大約爲XGBoost的1/6。

higgs數據集介紹:這是一個分類問題,用於區分產生希格斯玻色子的信號過程和不產生希格斯玻色子的信號過程。

數據鏈接

3 lightGBM原理

lightGBM 主要基於以下方面優化,提升整體特特性:

  1. 基於Histogram(直方圖)的決策樹算法
  2. Lightgbm 的Histogram(直方圖)做差加速
  3. 帶深度限制的Leaf-wise的葉子生長策略
  4. 直接支持類別特徵
  5. 直接支持高效並行

具體解釋見下,分節介紹。


3.1 基於Histogram(直方圖)的決策樹算法

直方圖算法的基本思想是

  • 先把連續的浮點特徵值離散化成k個整數,同時構造一個寬度爲k的直方圖。
  • 在遍歷數據的時候,根據離散化後的值作爲索引在直方圖中累積統計量,當遍歷一次數據後,直方圖累積了需要的統計量,然後根據直方圖的離散值,遍歷尋找最優的分割點。

Eg:

[0, 0.1) --> 0;

[0.1,0.3) --> 1;

...

使用直方圖算法有很多優點。首先,最明顯就是內存消耗的降低,直方圖算法不僅不需要額外存儲預排序的結果,而且可以只保存特徵離散化後的值,而這個值一般用8位整型存儲就足夠了,內存消耗可以降低爲原來的1/8。

然後在計算上的代價也大幅降低,預排序算法每遍歷一個特徵值就需要計算一次分裂的增益,而直方圖算法只需要計算k次(k可以認爲是常數),時間複雜度從O(#data#feature)優化到O(k#features)。

當然,Histogram算法並不是完美的。由於特徵被離散化後,找到的並不是很精確的分割點,所以會對結果產生影響。但在不同的數據集上的結果表明,離散化的分割點對最終的精度影響並不是很大,甚至有時候會更好一點。原因是決策樹本來就是弱模型,分割點是不是精確並不是太重要;較粗的分割點也有正則化的效果,可以有效地防止過擬合;即使單棵樹的訓練誤差比精確分割的算法稍大,但在梯度提升(Gradient Boosting)的框架下沒有太大的影響。

3.2 Lightgbm 的Histogram(直方圖)做差加速

一個葉子的直方圖可以由它的父親節點的直方圖與它兄弟的直方圖做差得到。

通常構造直方圖,需要遍歷該葉子上的所有數據,但直方圖做差僅需遍歷直方圖的k個桶。

利用這個方法,LightGBM可以在構造一個葉子的直方圖後,可以用非常微小的代價得到它兄弟葉子的直方圖,在速度上可以提升一倍。

3.3 帶深度限制的Leaf-wise的葉子生長策略

Level-wise便利一次數據可以同時分裂同一層的葉子,容易進行多線程優化,也好控制模型複雜度,不容易過擬合。

  • 但實際上Level-wise是一種低效的算法,因爲它不加區分的對待同一層的葉子,帶來了很多沒必要的開銷,因爲實際上很多葉子的分裂增益較低,沒必要進行搜索和分裂。

Leaf-wise則是一種更爲高效的策略,每次從當前所有葉子中,找到分裂增益最大的一個葉子,然後分裂,如此循環。

  • 因此同Level-wise相比,在分裂次數相同的情況下,Leaf-wise可以降低更多的誤差,得到更好的精度。
  • Leaf-wise的缺點是可能會長出比較深的決策樹,產生過擬合。因此LightGBM在Leaf-wise之上增加了一個最大深度的限制,在保證高效率的同時防止過擬合。

3.4 直接支持類別特徵

實際上大多數機器學習工具都無法直接支持類別特徵,一般需要把類別特徵,轉化到多維的0/1特徵,降低了空間和時間的效率。

而類別特徵的使用是在實踐中很常用的。基於這個考慮,LightGBM優化了對類別特徵的支持,可以直接輸入類別特徵,不需要額外的0/1展開。並在決策樹算法上增加了類別特徵的決策規則。

在Expo數據集上的實驗,相比0/1展開的方法,訓練速度可以加速8倍,並且精度一致。目前來看,LightGBM是第一個直接支持類別特徵的GBDT工具。

Expo數據集介紹:數據包含1987年10月至2008年4月美國境內所有商業航班的航班到達和離開的詳細信息。這是一個龐大的數據集:總共有近1.2億條記錄。主要用於預測航班是否準時。

數據鏈接

3.5 直接支持高效並行

LightGBM還具有支持高效並行的優點。LightGBM原生支持並行學習,目前支持特徵並行和數據並行的兩種。

  • 特徵並行的主要思想是在不同機器在不同的特徵集合上分別尋找最優的分割點,然後在機器間同步最優的分割點。
  • 數據並行則是讓不同的機器先在本地構造直方圖,然後進行全局的合併,最後在合併的直方圖上面尋找最優分割點。

LightGBM針對這兩種並行方法都做了優化:

  • 特徵並行算法中,通過在本地保存全部數據避免對數據切分結果的通信;

數據並行中使用分散規約 (Reduce scatter) 把直方圖合併的任務分攤到不同的機器,降低通信和計算,並利用直方圖做差,進一步減少了一半的通信量。

基於投票的數據並行(Voting Parallelization)則進一步優化數據並行中的通信代價,使通信代價變成常數級別。在數據量很大的時候,使用投票並行可以得到非常好的加速效果。

4 小結

  • lightGBM 演進過程

lightGBM優勢

  • 基於Histogram(直方圖)的決策樹算法
  • Lightgbm 的Histogram(直方圖)做差加速
  • 帶深度限制的Leaf-wise的葉子生長策略
  • 直接支持類別特徵
  • 直接支持高效並行

5.6 lightGBM算法api介紹

1 lightGBM的安裝

  • windows下:
pip3 install lightgbm

2 lightGBM參數介紹

2.1 Control Parameters

Control Parameters 含義 用法
max_depth 樹的最大深度 當模型過擬合時,可以考慮首先降低 max_depth
min_data_in_leaf 葉子可能具有的最小記錄數 默認20,過擬合時用
feature_fraction 例如 爲0.8時,意味着在每次迭代中隨機選擇80%的參數來建樹 boosting 爲 random forest 時用
bagging_fraction 每次迭代時用的數據比例 用於加快訓練速度和減小過擬合
early_stopping_round 如果一次驗證數據的一個度量在最近的early_stopping_round 回合中沒有提高,模型將停止訓練 加速分析,減少過多迭代
lambda 指定正則化 0~1
min_gain_to_split 描述分裂的最小 gain 控制樹的有用的分裂
max_cat_group 在 group 邊界上找到分割點 當類別數量很多時,找分割點很容易過擬合時
n_estimators 最大迭代次數 最大迭代數不必設置過大,可以在進行一次迭代後,根據最佳迭代數設置

2.2 Core Parameters

Core Parameters 含義 用法
Task 數據的用途 選擇 train 或者 predict
application 模型的用途 選擇 regression: 迴歸時,
binary: 二分類時,
multiclass: 多分類時
boosting 要用的算法 gbdt,
rf: random forest,
dart: Dropouts meet Multiple Additive Regression Trees,
goss: Gradient-based One-Side Sampling
num_boost_round 迭代次數 通常 100+
learning_rate 學習率 常用 0.1, 0.001, 0.003…
num_leaves 葉子數量 默認 31
device   cpu 或者 gpu
metric   mae: mean absolute error ,
mse: mean squared error ,
binary_logloss: loss for binary classification ,
multi_logloss: loss for multi classification

2.3 IO parameter

IO parameter 含義
max_bin 表示 feature 將存入的 bin 的最大數量
categorical_feature 如果 categorical_features = 0,1,2, 則列 0,1,2是 categorical 變量
ignore_column 與 categorical_features 類似,只不過不是將特定的列視爲categorical,而是完全忽略
save_binary 這個參數爲 true 時,則數據集被保存爲二進制文件,下次讀數據時速度會變快

3 調參建議

IO parameter 含義
num_leaves 取值應 <= 2^{(max\_depth)}2​(max_depth)​​, 超過此值會導致過擬合
min_data_in_leaf 將它設置爲較大的值可以避免生長太深的樹,但可能會導致 underfitting,在大型數據集時就設置爲數百或數千
max_depth 這個也是可以限制樹的深度

下表對應了 Faster Speed ,better accuracy ,over-fitting 三種目的時,可以調的參數

Faster Speed better accuracy over-fitting
將 max_bin 設置小一些 用較大的 max_bin max_bin 小一些
  num_leaves 大一些 num_leaves 小一些
用 feature_fraction來做 sub-sampling   用 feature_fraction
用 bagging_fraction 和 bagging_freq   設定 bagging_fraction 和 bagging_freq
  training data 多一些 training data 多一些
用 save_binary來加速數據加載 直接用 categorical feature 用 gmin_data_in_leaf 和 min_sum_hessian_in_leaf
用 parallel learning 用 dart 用 lambda_l1, lambda_l2 ,min_gain_to_split 做正則化
  num_iterations 大一些,learning_rate小一些 用 max_depth 控制樹的深度

5.7 lightGBM案例介紹

接下來,通過鳶尾花數據集對lightGBM的基本使用,做一個介紹。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_squared_error
import lightgbm as lgb

加載數據,對數據進行基本處理

# 加載數據
iris = load_iris()
data = iris.data
target = iris.target
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

模型訓練

gbm = lgb.LGBMRegressor(objective='regression', learning_rate=0.05, n_estimators=20)

gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], eval_metric='l1', early_stopping_rounds=5)

gbm.score(X_test, y_test)
# 0.810605595102488

#  網格搜索,參數優化
estimator = lgb.LGBMRegressor(num_leaves=31)
param_grid = {
    'learning_rate': [0.01, 0.1, 1],
    'n_estimators': [20, 40]
}
gbm = GridSearchCV(estimator, param_grid, cv=4)
gbm.fit(X_train, y_train)
print('Best parameters found by grid search are:', gbm.best_params_)
# Best parameters found by grid search are: {'learning_rate': 0.1, 'n_estimators': 40}

模型調優訓練

gbm = lgb.LGBMRegressor(num_leaves=31, learning_rate=0.1, n_estimators=40)

gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], eval_metric='l1', early_stopping_rounds=5)

gbm.score(X_test, y_test)
# 0.9536626296481988

In [1]:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_squared_error
import lightgbm as lgb

讀取數據

In [2]:

iris = load_iris()
data = iris.data
target = iris.target

In [3]:

data

Out[3]:

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])

In [4]:

target

Out[4]:

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

數據基本處理

In [5]:

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

模型訓練

模型基本訓練

In [14]:

gbm = lgb.LGBMRegressor(objective="regression", learning_rate=0.05, n_estimators=20)

gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], eval_metric="l1", early_stopping_rounds=3)
gbm.score(X_test, y_test)

[1]	valid_0's l1: 0.653531	valid_0's l2: 0.626219
Training until validation scores don't improve for 3 rounds
[2]	valid_0's l1: 0.626209	valid_0's l2: 0.57348
[3]	valid_0's l1: 0.60108	valid_0's l2: 0.525437
[4]	valid_0's l1: 0.577988	valid_0's l2: 0.482521
[5]	valid_0's l1: 0.555301	valid_0's l2: 0.443297
[6]	valid_0's l1: 0.534806	valid_0's l2: 0.408881
[7]	valid_0's l1: 0.510834	valid_0's l2: 0.372852
[8]	valid_0's l1: 0.491373	valid_0's l2: 0.344015
[9]	valid_0's l1: 0.469678	valid_0's l2: 0.314384
[10]	valid_0's l1: 0.451908	valid_0's l2: 0.290418
[11]	valid_0's l1: 0.433932	valid_0's l2: 0.268274
[12]	valid_0's l1: 0.414266	valid_0's l2: 0.245211
[13]	valid_0's l1: 0.398027	valid_0's l2: 0.227095
[14]	valid_0's l1: 0.380293	valid_0's l2: 0.208076
[15]	valid_0's l1: 0.365621	valid_0's l2: 0.193252
[16]	valid_0's l1: 0.34957	valid_0's l2: 0.177498
[17]	valid_0's l1: 0.336313	valid_0's l2: 0.16537
[18]	valid_0's l1: 0.321785	valid_0's l2: 0.152308
[19]	valid_0's l1: 0.310088	valid_0's l2: 0.142386
[20]	valid_0's l1: 0.298266	valid_0's l2: 0.131543
Did not meet early stopping. Best iteration is:
[20]	valid_0's l1: 0.298266	valid_0's l2: 0.131543

Out[14]:

0.7578964818630016

通過網格搜索進行訓練

In [11]:

estimators = lgb.LGBMRegressor(num_leaves=31)
param_grid = {
    "learning_rate": [0.01, 0.1, 1],
    "n_estmators":[20, 40, 60, 80]
}
gbm = GridSearchCV(estimators, param_grid, cv=5)
gbm.fit(X_train, y_train)

Out[11]:

GridSearchCV(cv=5, error_score=nan,
             estimator=LGBMRegressor(boosting_type='gbdt', class_weight=None,
                                     colsample_bytree=1.0,
                                     importance_type='split', learning_rate=0.1,
                                     max_depth=-1, min_child_samples=20,
                                     min_child_weight=0.001, min_split_gain=0.0,
                                     n_estimators=100, n_jobs=-1, num_leaves=31,
                                     objective=None, random_state=None,
                                     reg_alpha=0.0, reg_lambda=0.0, silent=True,
                                     subsample=1.0, subsample_for_bin=200000,
                                     subsample_freq=0),
             iid='deprecated', n_jobs=None,
             param_grid={'learning_rate': [0.01, 0.1, 1],
                         'n_estmators': [20, 40, 60, 80]},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=None, verbose=0)

In [12]:

gbm.best_params_

Out[12]:

{'learning_rate': 0.1, 'n_estmators': 20}

In [13]:

gbm = lgb.LGBMRegressor(objective="regression", learning_rate=0.1, n_estimators=20)

gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], eval_metric="l1", early_stopping_rounds=3)
gbm.score(X_test, y_test)

[1]	valid_0's l1: 0.625261	valid_0's l2: 0.571453
Training until validation scores don't improve for 3 rounds
[2]	valid_0's l1: 0.574385	valid_0's l2: 0.477181
[3]	valid_0's l1: 0.531459	valid_0's l2: 0.403427
[4]	valid_0's l1: 0.483888	valid_0's l2: 0.33428
[5]	valid_0's l1: 0.447306	valid_0's l2: 0.284716
[6]	valid_0's l1: 0.413883	valid_0's l2: 0.243537
[7]	valid_0's l1: 0.377047	valid_0's l2: 0.203656
[8]	valid_0's l1: 0.348048	valid_0's l2: 0.175576
[9]	valid_0's l1: 0.318049	valid_0's l2: 0.148479
[10]	valid_0's l1: 0.29463	valid_0's l2: 0.129983
[11]	valid_0's l1: 0.27226	valid_0's l2: 0.111468
[12]	valid_0's l1: 0.2489	valid_0's l2: 0.0960426
[13]	valid_0's l1: 0.230634	valid_0's l2: 0.0833998
[14]	valid_0's l1: 0.216687	valid_0's l2: 0.0759234
[15]	valid_0's l1: 0.1993	valid_0's l2: 0.0670385
[16]	valid_0's l1: 0.188099	valid_0's l2: 0.0622206
[17]	valid_0's l1: 0.178022	valid_0's l2: 0.058299
[18]	valid_0's l1: 0.168954	valid_0's l2: 0.0551119
[19]	valid_0's l1: 0.158303	valid_0's l2: 0.0505529
[20]	valid_0's l1: 0.149623	valid_0's l2: 0.0466022
Did not meet early stopping. Best iteration is:
[20]	valid_0's l1: 0.149623	valid_0's l2: 0.0466022

Out[13]:

0.9142290887795029
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章