xgboost在spark集羣使用指南

    XGBoost是一個優化的分佈式梯度增強庫,具有高效、靈活和可移植性。在梯度增強框架下實現了機器學習算法。XGBoost提供了一種並行樹增強(也稱爲GBDT、GBM),可以快速、準確地解決許多數據科學問題。相同的代碼在主要的分佈式環境(Hadoop、SGE、MPI)上運行,可以解決數十億個示例的訓練問題。

  XGBoost可以使用R、python、java、scala實現,本文主要講解採用scala+spark的實現方式。

1.maven環境配置

<dependency>
   <groupId>ml.dmlc</groupId>
   <artifactId>xgboost4j</artifactId>
   <version>0.90</version>
</dependency>
<dependency>
     <groupId>ml.dmlc</groupId>
     <artifactId>xgboost4j-spark</artifactId>
     <version>0.90</version>
</dependency>

2.使用分類方法訓練xbg模型

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
val xgbParam = Map("eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "multi:softprob",
      "num_class" -> 3,
      "num_round" -> 100,
      "num_workers" -> 2)
val xgbClassifier = new XGBoostClassifier(xgbParam).
      setFeaturesCol("features").
      setLabelCol("classIndex")

在XGBoost4J-Spark中,不僅支持默認的參數集,而且還支持這些參數的大小寫變體,以保持與Spark的MLLIB參數的一致性。

在設置好XGBoostClassifier參數和feature/label列之後,可以通過將分類器與輸入數據流進行擬合,來構建一個轉換器XGBoostClassificationModel。這種擬合操作本質上就是訓練過程,生成的模型可以用於預測。

val xgbClassificationModel = xgbClassifier.fit(xgbInput)

運行中遇到的坑:

1.程序一直在運行中,無法完成

出現這種情況就是你在初始化spark master的時候給的線程數小於你的work_number,切記:

master('local[m]')

work_number(n)

一定要 m >= n

測試環境下一般work_number設置爲1 即可

2.XGBoostModel training failed

ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:582)
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:459)
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:435)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
  at scala.collection.immutable.List.foreach(List.scala:383)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:245)
  at scala.collection.immutable.List.map(List.scala:286)
  at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:434)
  at ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train(XGBoostClassifier.scala:194)

這種情況一般是由於spark與xgboost4j的版本不匹配導致的,xgboost 9.0 必須對應spark 2.4以上版本,xgboost 8.1 必須對應spark 2.31以上版本。修改版本後程序正確運行。

參考官方文檔:https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章