xgboost在spark集群使用指南

    XGBoost是一个优化的分布式梯度增强库,具有高效、灵活和可移植性。在梯度增强框架下实现了机器学习算法。XGBoost提供了一种并行树增强(也称为GBDT、GBM),可以快速、准确地解决许多数据科学问题。相同的代码在主要的分布式环境(Hadoop、SGE、MPI)上运行,可以解决数十亿个示例的训练问题。

  XGBoost可以使用R、python、java、scala实现,本文主要讲解采用scala+spark的实现方式。

1.maven环境配置

<dependency>
   <groupId>ml.dmlc</groupId>
   <artifactId>xgboost4j</artifactId>
   <version>0.90</version>
</dependency>
<dependency>
     <groupId>ml.dmlc</groupId>
     <artifactId>xgboost4j-spark</artifactId>
     <version>0.90</version>
</dependency>

2.使用分类方法训练xbg模型

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
val xgbParam = Map("eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "multi:softprob",
      "num_class" -> 3,
      "num_round" -> 100,
      "num_workers" -> 2)
val xgbClassifier = new XGBoostClassifier(xgbParam).
      setFeaturesCol("features").
      setLabelCol("classIndex")

在XGBoost4J-Spark中,不仅支持默认的参数集,而且还支持这些参数的大小写变体,以保持与Spark的MLLIB参数的一致性。

在设置好XGBoostClassifier参数和feature/label列之后,可以通过将分类器与输入数据流进行拟合,来构建一个转换器XGBoostClassificationModel。这种拟合操作本质上就是训练过程,生成的模型可以用于预测。

val xgbClassificationModel = xgbClassifier.fit(xgbInput)

运行中遇到的坑:

1.程序一直在运行中,无法完成

出现这种情况就是你在初始化spark master的时候给的线程数小于你的work_number,切记:

master('local[m]')

work_number(n)

一定要 m >= n

测试环境下一般work_number设置为1 即可

2.XGBoostModel training failed

ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:582)
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:459)
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:435)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
  at scala.collection.immutable.List.foreach(List.scala:383)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:245)
  at scala.collection.immutable.List.map(List.scala:286)
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:434)
  at ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train(XGBoostClassifier.scala:194)

这种情况一般是由于spark与xgboost4j的版本不匹配导致的,xgboost 9.0 必须对应spark 2.4以上版本,xgboost 8.1 必须对应spark 2.31以上版本。修改版本后程序正确运行。

参考官方文档:https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章