XGBoost是一个优化的分布式梯度增强库,具有高效、灵活和可移植性。在梯度增强框架下实现了机器学习算法。XGBoost提供了一种并行树增强(也称为GBDT、GBM),可以快速、准确地解决许多数据科学问题。相同的代码在主要的分布式环境(Hadoop、SGE、MPI)上运行,可以解决数十亿个示例的训练问题。
XGBoost可以使用R、python、java、scala实现,本文主要讲解采用scala+spark的实现方式。
1.maven环境配置
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.90</version>
</dependency>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark</artifactId>
<version>0.90</version>
</dependency>
2.使用分类方法训练xbg模型
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
val xgbParam = Map("eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
在XGBoost4J-Spark中,不仅支持默认的参数集,而且还支持这些参数的大小写变体,以保持与Spark的MLLIB参数的一致性。
在设置好XGBoostClassifier参数和feature/label列之后,可以通过将分类器与输入数据流进行拟合,来构建一个转换器XGBoostClassificationModel。这种拟合操作本质上就是训练过程,生成的模型可以用于预测。
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
运行中遇到的坑:
1.程序一直在运行中,无法完成
出现这种情况就是你在初始化spark master的时候给的线程数小于你的work_number,切记:
master('local[m]')
work_number(n)
一定要 m >= n
测试环境下一般work_number设置为1 即可
2.XGBoostModel training failed
ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:582)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:459)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:435)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
at scala.collection.immutable.List.foreach(List.scala:383)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:245)
at scala.collection.immutable.List.map(List.scala:286)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:434)
at ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train(XGBoostClassifier.scala:194)
这种情况一般是由于spark与xgboost4j的版本不匹配导致的,xgboost 9.0 必须对应spark 2.4以上版本,xgboost 8.1 必须对应spark 2.31以上版本。修改版本后程序正确运行。
参考官方文档:https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html