對比XGBoost.cv和sklearn中的交叉驗證

寫在前面:已經很久很久很久沒有發博客了,有點愧疚還有點難過,不寫博客的實踐都幹嘛了,哎!!!

XGBoost有兩種接口:

  1. 原生接口,比如xgboost.trainxgboost.cv
  2. sklearn接口,比如xgboost.XGBClassifierxgboost.XGBRegressor

兩種接口有些許不同,比如原生接口的學習率參數是eta,sklearn接口的是learning_rate,原生接口要在traincv函數中傳入num_round作爲基學習器個數,而sklearn接口在定義模型時使用參數n_estimators。sklearn接口的形式與sklearn中的模型保持統一,方便sklearn用戶學習。

如果要對XGBoost模型進行交叉驗證,可以使用原生接口的交叉驗證函數xgboost.cv;對於sklearn接口,可以使用sklearn.model_selection中的cross_val_scorecross_validatevalidation_curve三個函數。

sklearn.model_selection中的三個函數區別:

  1. cross_val_score最簡單,返回模型給定參數的驗證得分,不能返回訓練得分
  2. cross_validate複雜一些,返回模型給定參數的訓練得分、驗證得分、訓練時間和驗證時間等,甚至還可以指定多個評價指標
  3. validation_curve返回模型指定一個參數的一系列候選值的訓練得分和驗證得分,可以通過判斷擬合情況來調整該參數,也可以用來畫validation_curve

下面分別以分類任務和迴歸任務展示一下四個函數的用法和輸出情況。經過對比,在參數相同的條件下,四個函數的輸出結果一致。發現了一個問題,validation_curvexgboost.cv的輸出結果大部分相同,但是前者的耗時卻比後者多了好幾倍。(暫時還找到原因,網上也沒找到相同的問題,打算到stackoverflow上問一下,如果有答案的話再回來補充)

20200402補充:初步懷疑是熱啓動的問題,在使用xgboost.cv進行交叉驗證時,可以通過熱啓動的方式訓練模型,此時只需要訓練NN棵樹;而把XGBRegressor傳入validation_curve進行交叉驗證,此時XGBRegressor不能設置熱啓動(而sklearn的GBDT和隨機森林都可以設置熱啓動),那就需要訓練1+2+...+N=N(N1)21+2+...+N = \frac{N*(N-1)}{2}棵樹,自然速度就慢了。

P.S. 下面代碼是用Jupyter Notebook寫的,懶得合併了。

import numpy as np
import xgboost as xgb

from sklearn.datasets import make_regression
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate
from sklearn.model_selection import KFold
from sklearn.model_selection import validation_curve
# 迴歸問題
X, y = make_regression(n_samples=10000, n_features=10)
# sklearn接口
n_estimators = 50
params = {'n_estimators':n_estimators, 'booster':'gbtree', 'max_depth':5, 'learning_rate':0.05,
          'objective':'reg:squarederror', 'subsample':1, 'colsample_bytree':1}
clf = xgb.XGBRegressor(**params)
cv = KFold(n_splits=5, shuffle=True, random_state=100)
print('test_score:', cross_val_score(clf, X, y, cv=cv, scoring='neg_mean_absolute_error'))
test_score: array([-57.48422753, -59.69255262, -58.91771172, -58.44347715,
       -59.8880623 ])
cross_validate(clf, X, y, cv=cv, scoring='neg_mean_absolute_error', 
               return_train_score=True)
{'fit_time': array([0.37278223, 0.36898613, 0.36637878, 0.36504936, 0.37162185]),
 'score_time': array([0.00398517, 0.00403547, 0.00398993, 0.00398636, 0.00404048]),
 'test_score': array([-57.48422753, -59.69255262, -58.91771172, -58.44347715,
        -59.8880623 ]),
 'train_score': array([-50.70099151, -50.43187094, -50.75229625, -50.66844022,
        -50.82982251])}
%%time # 計算一個cell的執行實踐

estimator_range = range(1, n_estimators+1)
train_score, test_score = validation_curve(
    clf, X, y, param_name='n_estimators', param_range=estimator_range,
    cv=cv, scoring='neg_mean_absolute_error'
)

print('train_score:',train_score[-1])
print('test_score:', test_score[-1])
train_score: [-50.70099151 -50.43187094 -50.75229625 -50.66844022 -50.82982251]
test_score: [-57.48422753 -59.69255262 -58.91771172 -58.44347715 -59.8880623 ]
Wall time: 57 s
print('train_mae_mean:\n', np.abs(train_score).mean(axis=1))
print('test_mae_mean:\n', np.abs(test_score).mean(axis=1))
train_mae_mean: 
 array([127.5682212 , 124.17645861, 120.96190697, 117.93807824,
        115.06161926, 112.28746068, 109.60911311, 107.07263957,
        104.63554663, 102.2788341 ,  99.97895509,  97.82509892,
         95.73958223,  93.71896245,  91.79974093,  89.94817809,
         88.13495265,  86.37829884,  84.67090725,  83.05548799,
         81.46903821,  79.94032864,  78.44072613,  77.00488358,
         75.62191062,  74.24138916,  72.92121361,  71.66007955,
         70.41908351,  69.23718699,  68.06376522,  66.92736292,
         65.81928473,  64.73408044,  63.67274508,  62.65390845,
         61.66069004,  60.69802867,  59.72997524,  58.7870726 ,
         57.88241178,  57.01307807,  56.14014094,  55.30247271,
         54.48416963,  53.69873843,  52.91791742,  52.16280788,
         51.42670887,  50.67668428]),
test_mae_mean:
 array([127.83738044, 124.65000719, 121.72020148, 118.90983369,
        116.24294452, 113.69376675, 111.29388701, 108.94996321,
        106.7553701 , 104.62401193, 102.5608943 , 100.68648486,
         98.76550219,  96.97546939,  95.20893969,  93.55259092,
         91.9299438 ,  90.3413075 ,  88.76142948,  87.34316226,
         85.96043718,  84.62054143,  83.30115705,  82.07698107,
         80.89857637,  79.67939585,  78.52190061,  77.37787457,
         76.28248431,  75.24121599,  74.2093299 ,  73.21873113,
         72.19303325,  71.23265487,  70.33854865,  69.42902278,
         68.57191177,  67.73459769,  66.88130101,  66.05978781,
         65.26603807,  64.46357751,  63.70019472,  62.95398889,
         62.25243534,  61.56164243,  60.88819753,  60.20476192,
         59.55280602,  58.88520627])
%%time

params_xgb = params.copy() # 修改參數
num_round = params_xgb['n_estimators']
params_xgb['eta'] = params['learning_rate']
del params_xgb['n_estimators']
del params_xgb['learning_rate']

# xgboost原生接口 進行交叉驗證
res = xgb.cv(params_xgb, xgb.DMatrix(X, y), num_round, folds=cv, metrics='mae')
print(res)
    train-mae-mean  train-mae-std  test-mae-mean  test-mae-std
0       127.568312       0.315528     127.837350      1.243183
1       124.176437       0.300477     124.649957      1.236916
2       120.962018       0.301030     121.720238      1.206761
3       117.938005       0.278763     118.909902      1.231662
4       115.061696       0.269224     116.242946      1.190097
5       112.287560       0.240412     113.693771      1.159047
6       109.609152       0.262167     111.293890      1.099815
7       107.072640       0.242916     108.949971      1.067070
8       104.635579       0.209314     106.755350      1.080068
9       102.278841       0.195815     104.624013      1.054731
10       99.978919       0.201804     102.560906      1.055403
11       97.825169       0.213528     100.686517      1.033271
12       95.739612       0.202356      98.765524      1.029646
13       93.719107       0.187538      96.975470      1.005893
14       91.799744       0.175199      95.208905      1.046983
15       89.948177       0.160738      93.552597      1.067333
16       88.134976       0.144838      91.929965      1.052541
17       86.378351       0.163211      90.341278      1.037858
18       84.670908       0.187184      88.761414      0.995875
19       83.055446       0.171080      87.343141      0.981363
20       81.469022       0.164968      85.960420      0.993623
21       79.940317       0.167554      84.620523      0.963820
22       78.440726       0.154343      83.301137      1.004986
23       77.004854       0.141827      82.076961      0.986129
24       75.621930       0.150028      80.898605      0.964261
25       74.241496       0.154140      79.679413      0.949695
26       72.921170       0.140105      78.521875      0.946750
27       71.660085       0.130937      77.377856      0.924869
28       70.419052       0.109023      76.282506      0.928389
29       69.237167       0.107013      75.241214      0.900845
30       68.063844       0.097079      74.209323      0.900476
31       66.927363       0.091163      73.218730      0.942131
32       65.819266       0.091109      72.193025      0.930880
33       64.734090       0.092792      71.232658      0.908819
34       63.672701       0.086543      70.338522      0.932795
35       62.653945       0.088487      69.429022      0.927500
36       61.660666       0.082703      68.571904      0.915664
37       60.697992       0.119144      67.734601      0.882644
38       59.729960       0.126423      66.881299      0.886910
39       58.787107       0.117820      66.059784      0.897685
40       57.882377       0.125402      65.266035      0.877481
41       57.013075       0.109192      64.463574      0.901940
42       56.140131       0.140454      63.700203      0.888990
43       55.302481       0.148805      62.953973      0.834368
44       54.484136       0.145519      62.252445      0.829440
45       53.698661       0.132748      61.561636      0.854725
46       52.917877       0.124366      60.888204      0.875071
47       52.162859       0.133974      60.204764      0.878531
48       51.426765       0.140143      59.552805      0.892451
49       50.676675       0.133987      58.885213      0.873657
Wall time: 2.25 s

validation_curve用了57s,而xgboost.cv只用了2.25s,差距巨大!

# 分類數據集
X, y = make_classification(n_samples=10000, n_features=10, n_classes=2)
n_estimators = 50
params = {'n_estimators':n_estimators, 'booster':'gbtree', 'max_depth':5, 'learning_rate':0.05,
          'objective':'binary:logistic', 'subsample':1, 'colsample_bytree':1}
clf = xgb.XGBClassifier(**params)
cv = KFold(n_splits=5, shuffle=True, random_state=100)
print('test_score:', cross_val_score(clf, X, y, cv=cv, scoring='accuracy'))
test_score: array([0.913 , 0.9235, 0.8955, 0.9075, 0.918 ])
cross_validate(clf, X, y, cv=cv, scoring='accuracy', 
               return_train_score=True)
{'fit_time': array([0.43403697, 0.43297029, 0.41813326, 0.42408895, 0.42200208]),
 'score_time': array([0.00299048, 0.00203776, 0.00500631, 0.0019989 , 0.00299263]),
 'test_score': array([0.913 , 0.9235, 0.8955, 0.9075, 0.918 ]),
 'train_score': array([0.92425 , 0.921125, 0.9285  , 0.92325 , 0.922125])}
%%time

estimator_range = range(1, n_estimators+1)
train_score, test_score = validation_curve(
    clf, X, y, param_name='n_estimators', param_range=estimator_range,
    cv=cv, scoring='accuracy'
)

print('train_score:',train_score[-1])
print('test_score:', test_score[-1])
train_score: [0.92425  0.921125 0.9285   0.92325  0.922125]
test_score: [0.913  0.9235 0.8955 0.9075 0.918 ]
Wall time: 58.7 s
print('train_mae_mean:\n', np.abs(train_score).mean(axis=1))
print('test_mae_mean:\n', np.abs(test_score).mean(axis=1))
train_score.mean(axis=1), test_score.mean(axis=1)
train_mae_mean:
 array([0.912775, 0.916075, 0.91585 , 0.91695 , 0.917125, 0.917225,
        0.91725 , 0.9175  , 0.91745 , 0.917925, 0.91755 , 0.918025,
        0.917975, 0.91835 , 0.918225, 0.918625, 0.919   , 0.91905 ,
        0.918975, 0.9191  , 0.91955 , 0.919525, 0.9198  , 0.9199  ,
        0.919975, 0.920025, 0.9201  , 0.92005 , 0.920125, 0.9208  ,
        0.921425, 0.9218  , 0.921875, 0.922025, 0.922125, 0.9221  ,
        0.92225 , 0.922275, 0.922275, 0.92235 , 0.9226  , 0.9229  ,
        0.923   , 0.9233  , 0.923375, 0.923275, 0.923325, 0.9234  ,
        0.923675, 0.92385 ]),
test_mae_mean:
 array([0.9049, 0.9072, 0.9082, 0.9085, 0.9087, 0.9084, 0.9082, 0.9091,
        0.9087, 0.9089, 0.9091, 0.9092, 0.9089, 0.9101, 0.9102, 0.9108,
        0.9102, 0.9107, 0.9105, 0.9109, 0.9104, 0.9102, 0.9109, 0.9109,
        0.9103, 0.9105, 0.9105, 0.9103, 0.9106, 0.9111, 0.9121, 0.9124,
        0.9124, 0.9122, 0.9119, 0.912 , 0.912 , 0.9117, 0.9114, 0.911 ,
        0.911 , 0.9113, 0.9111, 0.9107, 0.9108, 0.911 , 0.9109, 0.9113,
        0.9114, 0.9115])
%%time

params_xgb = params.copy()
num_round = params_xgb['n_estimators']
params_xgb['eta'] = params['learning_rate']
del params_xgb['n_estimators']
del params_xgb['learning_rate']

res = xgb.cv(params_xgb, xgb.DMatrix(X, y), num_round, folds=cv, metrics='error')
Wall time: 2.37 s
res['train-error-mean'] = 1 - res['train-error-mean']
res['test-error-mean'] = 1 - res['test-error-mean']
print(res)
    train-error-mean  train-error-std  test-error-mean  test-error-std
0           0.912775         0.002296           0.9049        0.007493
1           0.916075         0.003749           0.9072        0.007679
2           0.915850         0.003048           0.9082        0.006615
3           0.916950         0.002090           0.9085        0.008503
4           0.917125         0.002028           0.9087        0.008606
5           0.917225         0.002191           0.9084        0.009356
6           0.917250         0.002219           0.9082        0.009114
7           0.917500         0.002318           0.9091        0.009672
8           0.917450         0.002308           0.9087        0.009405
9           0.917925         0.002467           0.9089        0.009410
10          0.917550         0.002248           0.9091        0.009313
11          0.918025         0.002384           0.9092        0.009389
12          0.917975         0.002583           0.9089        0.009124
13          0.918350         0.002095           0.9101        0.008840
14          0.918225         0.002223           0.9102        0.008658
15          0.918625         0.002204           0.9108        0.008388
16          0.919000         0.002904           0.9102        0.009495
17          0.919050         0.002639           0.9107        0.008376
18          0.918975         0.002451           0.9105        0.008562
19          0.919100         0.002613           0.9109        0.008645
20          0.919550         0.003244           0.9104        0.008570
21          0.919525         0.003234           0.9102        0.008761
22          0.919800         0.003307           0.9109        0.008505
23          0.919900         0.003537           0.9109        0.008505
24          0.919975         0.003535           0.9103        0.008376
25          0.920025         0.003365           0.9105        0.008087
26          0.920100         0.003451           0.9105        0.008390
27          0.920050         0.003514           0.9103        0.008412
28          0.920125         0.003521           0.9106        0.007908
29          0.920800         0.003303           0.9111        0.008351
30          0.921425         0.002912           0.9121        0.009330
31          0.921800         0.002910           0.9124        0.009330
32          0.921875         0.002739           0.9124        0.009330
33          0.922025         0.002837           0.9122        0.009405
34          0.922125         0.002860           0.9119        0.009957
35          0.922100         0.002807           0.9120        0.009497
36          0.922250         0.002777           0.9120        0.009370
37          0.922275         0.002636           0.9117        0.009569
38          0.922275         0.002540           0.9114        0.009609
39          0.922350         0.002477           0.9110        0.009680
40          0.922600         0.002607           0.9110        0.009633
41          0.922900         0.002838           0.9113        0.010033
42          0.923000         0.002787           0.9111        0.009805
43          0.923300         0.002612           0.9107        0.009516
44          0.923375         0.002614           0.9108        0.009667
45          0.923275         0.002897           0.9110        0.009772
46          0.923325         0.002718           0.9109        0.009728
47          0.923400         0.002685           0.9113        0.009811
48          0.923675         0.002820           0.9114        0.009764
49          0.923850         0.002551           0.9115        0.009597

validation_curve用了58.7s,而xgboost.cv只用了2.37s,差距巨大!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章