XGBoost是一個優化的分佈式梯度增強庫,具有高效、靈活和可移植性。在梯度增強框架下實現了機器學習算法。XGBoost提供了一種並行樹增強(也稱爲GBDT、GBM),可以快速、準確地解決許多數據科學問題。相同的代碼在主要的分佈式環境(Hadoop、SGE、MPI)上運行,可以解決數十億個示例的訓練問題。
XGBoost可以使用R、python、java、scala實現,本文主要講解採用scala+spark的實現方式。
1.maven環境配置
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.90</version>
</dependency>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark</artifactId>
<version>0.90</version>
</dependency>
2.使用分類方法訓練xbg模型
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
val xgbParam = Map("eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
在XGBoost4J-Spark中,不僅支持默認的參數集,而且還支持這些參數的大小寫變體,以保持與Spark的MLLIB參數的一致性。
在設置好XGBoostClassifier參數和feature/label列之後,可以通過將分類器與輸入數據流進行擬合,來構建一個轉換器XGBoostClassificationModel。這種擬合操作本質上就是訓練過程,生成的模型可以用於預測。
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
運行中遇到的坑:
1.程序一直在運行中,無法完成
出現這種情況就是你在初始化spark master的時候給的線程數小於你的work_number,切記:
master('local[m]')
work_number(n)
一定要 m >= n
測試環境下一般work_number設置爲1 即可
2.XGBoostModel training failed
ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:582)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:459)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:435)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
at scala.collection.immutable.List.foreach(List.scala:383)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:245)
at scala.collection.immutable.List.map(List.scala:286)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:434)
at ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train(XGBoostClassifier.scala:194)
這種情況一般是由於spark與xgboost4j的版本不匹配導致的,xgboost 9.0 必須對應spark 2.4以上版本,xgboost 8.1 必須對應spark 2.31以上版本。修改版本後程序正確運行。
參考官方文檔:https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html