Spark GBDT Source Code Analysis (Gradient Boosted Decision Trees: GBTClassifier for classification, GBTRegressor for regression) (Spark 2.2.0)

Introduction to the GBDT Algorithm

【Overview】

            GBDT (Gradient Boosted Decision Trees) is an ensemble learning algorithm. Spark provides two implementations of it in MLlib: GBTClassifier and GBTRegressor.

【Computation Flow of the Spark Implementation】

       1. If the current implementation is GBTClassifier, check whether the training set's labels contain values other than 0 and 1; if they do, abort with an exception, otherwise map 0 and 1 to -1 and +1. If the current implementation is GBTRegressor, the data is left unchanged.

       2. Configure the loss function and the impurity measure according to the implementation:

                          GBTClassifier    GBTRegressor
  Loss function (loss)    LogLoss          squared (L2), absolute (L1)
  Impurity (impurity)     Gini index       label-column variance

       3. Train the first regression tree as the initial model and set its weight to 1.

       4. Predict the labels of the training data with it.

       5. Re-label the training data: new label = -loss.gradient(pred, point.label). 【Note】gradient is bound to the loss function; see the sections below.

       6. Train a regression tree on the re-labeled training data and set the new model's weight to the step size (stepSize).

       7. Predict with the updated ensemble: prediction = previous iteration's prediction + current tree's prediction * current weight (step size).

       8. Repeat steps 5-7 until the configured maximum number of iterations is reached.

       9. Return the array of tree models and the per-model weights.

【Note】During the actual prediction phase, GBTClassifier converts the predictions back to 0 and 1 (shown in the code later).
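The flow above is the standard gradient-boosting update. Summarizing steps 5-7 (with $\eta$ the step size and $h_m$ the $m$-th regression tree):

$$\tilde{y}_i = -\left.\frac{\partial L(F(x_i), y_i)}{\partial F(x_i)}\right|_{F = F_{m-1}}, \qquad F_m(x) = F_{m-1}(x) + \eta\, h_m(x)$$

where $h_m$ is the regression tree fitted to the pseudo-labels $\tilde{y}_i$.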

Usage Example

 

    val gbtClassifier = new GBTClassifier()
      /* target (label) column */
      .setLabelCol("label")
      /* features column */
      .setFeaturesCol("features")
      /* loss type; only the logistic loss is supported for classification */
      .setLossType("logistic")
      /* maximum tree depth */
      .setMaxDepth(5)
      /* impurity measure */
      .setImpurity("gini")
      /* Checkpoint periodically to truncate the DAG, relieving pressure on the driver's
       * stack and on fault tolerance, and to persist intermediate data */
      .setCheckpointInterval(10)
      /* maximum number of iterations, i.e., the number of trees in the final ensemble */
      .setMaxIter(20)
      .setCacheNodeIds(false)
      .setMaxBins(32)
      .setMaxMemoryInMB(256)
      .setMinInfoGain(0.0)
      .setMinInstancesPerNode(1)
      .setSeed(31L)
      .setStepSize(0.1)
      .setSubsamplingRate(1.0)

    // trainingDF / testDF are assumed DataFrames with "label" and "features" columns;
    // the parameter values above are illustrative
    val model: GBTClassificationModel = gbtClassifier.fit(trainingDF)
    model.transform(testDF)
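For completeness, a minimal GBTRegressor sketch in the same style (illustrative; trainingDF is an assumed DataFrame with "label" and "features" columns):

    import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

    val gbtRegressor = new GBTRegressor()
      .setLabelCol("label")
      .setFeaturesCol("features")
      /* regression supports "squared" (L2) and "absolute" (L1) */
      .setLossType("squared")
      .setMaxDepth(5)
      .setMaxIter(20)
      .setStepSize(0.1)

    val regModel: GBTRegressionModel = gbtRegressor.fit(trainingDF)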

 

Loss Functions

There are two families of loss functions:

       1. The classification losses are encapsulated in GBTClassifierParams; only logistic is supported.

       2. The regression losses are encapsulated in GBTRegressorParams; two options are supported: squared (the L2 loss) and absolute (the L1 loss).

1. Classification loss implementation

  【Loss-type dispatch and instantiation code】

private[ml] object GBTClassifierParams { 
  /** The classification implementation supports only the "logistic" loss type */
  final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
// The import above renames LogLoss to OldLogLoss (likewise for the other losses)
...

override private[ml] def getOldLossType: OldLoss = {
    getLossType match {
      case "logistic" => OldLogLoss
      case _ =>
        // Should never happen because of check in setter method.
        throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
    }
  }

【The OldLogLoss implementation】

    LogLoss encapsulates both the gradient computation and the loss-value computation.

object LogLoss extends Loss {

  /**
   * Gradient computation, used to generate the new labels before each iteration.
   * Method to calculate the loss gradients for the gradient boosting calculation for binary 
   * classification
   * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
   * @param prediction Predicted label.
   * @param label True label.
   * @return Loss gradient
   */
  @Since("1.2.0")
  override def gradient(prediction: Double, label: Double): Double = {
    - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
  }
  /* Compute the prediction error */
  override private[spark] def computeError(prediction: Double, label: Double): Double = {
    val margin = 2.0 * label * prediction
    // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
    2.0 * MLUtils.log1pExp(-margin)
  }
}
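As a sanity check (not part of the source): computeError shows the loss is $L(y, F) = 2\log(1 + e^{-2yF})$ with $y \in \{-1, +1\}$; differentiating with respect to $F$ recovers exactly the gradient in the code:

$$\frac{\partial L}{\partial F} = \frac{2 \cdot (-2y)\, e^{-2yF}}{1 + e^{-2yF}} = \frac{-4y}{1 + e^{2yF}}$$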

 


2. Regression loss implementations

【Loss-type dispatch and instantiation code】

private[ml] object GBTRegressorParams {
  // The losses below should be lowercase.
  /** Accessor for supported loss settings: squared (L2), absolute (L1) */
  final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
...

 override private[ml] def getOldLossType: OldLoss = {
    getLossType match {
      /* L2 (squared) loss */
      case "squared" => OldSquaredError
      /* L1 (absolute) loss */
      case "absolute" => OldAbsoluteError
      case _ =>
        // Should never happen because of check in setter method.
        throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
    }
  }

【squared implementation】: the L2 loss

object SquaredError extends Loss {

  /**
   * Method to calculate the gradients for the gradient boosting calculation for least
   * squares error calculation.
   * The gradient with respect to F(x) is: - 2 (y - F(x))
   * @param prediction Predicted label.
   * @param label True label.
   * @return Loss gradient
   */
  @Since("1.2.0")
  override def gradient(prediction: Double, label: Double): Double = {
    - 2.0 * (label - prediction)
  }

  override private[spark] def computeError(prediction: Double, label: Double): Double = {
    val err = label - prediction
    err * err
  }
}
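For reference, differentiating the squared error confirms the gradient above: $\frac{\partial}{\partial F}(y - F)^2 = -2(y - F)$, so the pseudo-label $-\text{gradient}$ used in boosting is $2(y - F)$, i.e., twice the ordinary residual.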

【absolute implementation】: the L1 loss

object AbsoluteError extends Loss {

  /**
   * Method to calculate the gradients for the gradient boosting calculation for least
   * absolute error calculation.
   * The gradient with respect to F(x) is: sign(F(x) - y)
   * @param prediction Predicted label.
   * @param label True label.
   * @return Loss gradient
   */
  @Since("1.2.0")
  override def gradient(prediction: Double, label: Double): Double = {
    if (label - prediction < 0) 1.0 else -1.0
  }

  override private[spark] def computeError(prediction: Double, label: Double): Double = {
    val err = label - prediction
    math.abs(err)
  }
}
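Likewise for the absolute error: $\frac{\partial}{\partial F}|y - F| = \operatorname{sign}(F - y)$ for $F \neq y$; the code returns ±1 accordingly, breaking the tie at $y = F$ by returning -1.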

 

Split Selection Metrics (Column Impurity Measures)

   【Implementation】 By default:

               the GBDT classification implementation uses the Gini index as its split-selection (impurity) measure;

               the GBDT regression implementation uses the variance of the label column.

    【Note】Neither measure can be changed through the public API. (Note also that GradientBoostedTrees.boost, shown later, resets treeStrategy.impurity = OldVariance for every base tree, so in practice all GBT base trees are trained with variance impurity.) To substitute a custom impurity measure, modify the source below, package it into your project, and set spark.driver.userClassPathFirst=true and spark.executor.userClassPathFirst=true so that your classes take precedence.

 The defaults are bound to each algorithm as follows:

  def defaultStrategy(algo: Algo): Strategy = algo match {
    // For the GBDT classification implementation, use Gini as the impurity measure in the strategy
    case Algo.Classification =>
      new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
        numClasses = 2)
    // For the GBDT regression implementation, use Variance as the impurity measure in the strategy
    case Algo.Regression =>
      new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
        numClasses = 0)
  }

1. Gini index

   The Gini index can be computed in two ways:

   (1) When the per-class probabilities $p_k$ are given, the Gini index is:

    $$\mathrm{Gini}(p) = 1 - \sum_{k=1}^{K} p_k^2$$

   (2) When only class counts are available ($|C_k|$ instances of class $k$ out of $|D|$ total), the empirical Gini index is:

    $$\mathrm{Gini}(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|} \right)^2$$

 【Note】Spark's implementation uses the second form, estimating each $p_k$ as counts(k) / totalCount.

object Gini extends Impurity {

  /**
   * :: DeveloperApi ::
   * information calculation for multiclass classification
   * @param counts Array[Double] with counts for each label
   * @param totalCount sum of counts for all labels
   * @return information value, or 0 if totalCount = 0
   */
  @Since("1.1.0")
  @DeveloperApi
  override def calculate(counts: Array[Double], totalCount: Double): Double = {
    if (totalCount == 0) {
      return 0
    }
    val numClasses = counts.length
    var impurity = 1.0
    var classIndex = 0
    while (classIndex < numClasses) {
      val freq = counts(classIndex) / totalCount
      impurity -= freq * freq
      classIndex += 1
    }
    impurity
  }
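A quick numeric check of calculate (illustrative only): with counts of 3 and 7 over 10 instances, the empirical frequencies are 0.3 and 0.7, so the Gini index is 1 - 0.3^2 - 0.7^2 = 0.42:

import org.apache.spark.mllib.tree.impurity.Gini

// 3 instances of one class and 7 of the other, 10 in total
val g = Gini.calculate(Array(3.0, 7.0), 10.0)  // 0.42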

2. Variance (of the label column) implementation

object Variance extends Impurity {
  /**
   * :: DeveloperApi ::
   * variance calculation
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   * @return information value, or 0 if count = 0
   */
  @Since("1.0.0")
  @DeveloperApi
  override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
    if (count == 0) {
      return 0
    }
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }
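A numeric check (illustrative): for labels {1, 2, 3, 4} we have count = 4, sum = 10, sumSquares = 30, so calculate returns (30 - 10^2/4) / 4 = 1.25, the population variance of the labels:

import org.apache.spark.mllib.tree.impurity.Variance

// labels 1.0, 2.0, 3.0, 4.0 -> count = 4, sum = 10, sum of squares = 30
val v = Variance.calculate(4.0, 10.0, 30.0)  // 1.25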

 

Model Training

【Overview】

          During model training, both the classification and the regression implementations call GradientBoostedTrees.run(...), which returns an array of regression decision trees together with a weight for each tree; these are then wrapped into a GBTRegressionModel or a GBTClassificationModel respectively.

During data preparation, the classification implementation checks whether the training labels contain values other than 0 and 1, and aborts with an exception if they do.

【GBTRegressor】Data preparation, hyper-parameter packaging, and training driver code, with source comments

 override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
    /*
     * Get the arity of each categorical feature from the column metadata:
     * e.g., for a bucketized column the arity is the bucket count, and for a binary
     * column it is 2; the map key is the feature index
     */
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    /* Wrap each row into a LabeledPoint according to the configured labelCol and featuresCol */
    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    /* Number of feature columns */
    val numFeatures = oldDataset.first().features.size
    /* Package the default training strategy (impurity, loss function, max depth, iteration count, etc.) */
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
    /* Initialize the logging and metrics (timing) instrumentation */
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    /* Start gradient-boosted training; classification and regression share the same training
     * function, differing only in parameters and label adjustment */
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
      $(seed))
    /* Wrap the trained regression trees, their weights, and the feature count
     * (used for checks at prediction time) into the model object */
    val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
    /* Log success */
    instr.logSuccess(m)
    m
  }

【GBTClassifier】Data preparation, hyper-parameter packaging, and training driver code, with source comments

override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
    /* As in the regression implementation, compute the arity of each categorical feature */
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
    // 2 classes now.  This lets us provide a more precise error message.
    /* Check whether the label column contains values other than 0 or 1; if so, abort with an exception */
    val oldDataset: RDD[LabeledPoint] =
      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
        case Row(label: Double, features: Vector) =>
          require(label == 0 || label == 1, s"GBTClassifier was given" +
            s" dataset with invalid label $label.  Labels must be in {0,1}; note that" +
            s" GBTClassifier currently only supports binary classification.")
          LabeledPoint(label, features)
      }
    /* As in the regression implementation, get the number of feature columns */
    val numFeatures = oldDataset.first().features.size
    /* As in the regression implementation, package the strategy, including the impurity measure */
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
    /* As in the regression implementation, set up logging and metrics instrumentation */
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    instr.logNumClasses(2)
    /* As in the regression implementation, start training; the impurity measure and other
     * algorithm-specific differences have already been packaged into boostingStrategy */
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
      $(seed))
    /* Wrap the trained regression trees and the per-tree weights into a GBTClassificationModel */
    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
    instr.logSuccess(m)
    m
  }

 

【GradientBoostedTrees】Implementation and source comments

 【Overview】Both GBDT implementations call GradientBoostedTrees.run(...) to train the model.

Before training proper, the classification implementation re-wraps the training data, mapping the label column from {0, 1} to {-1, +1}. Both paths then call GradientBoostedTrees.boost (shown further below) to train the model.

The implementation of GradientBoostedTrees.run, with comments:

  def run(
      input: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
          seed)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }
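The remapping is a simple affine transform; as a quick check:

Seq(0.0, 1.0).map(label => label * 2 - 1)  // List(-1.0, 1.0)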

GradientBoostedTrees.boost, shown below with comments, trains the array of tree models and their weights:

 

/**
   * Internal method for performing regression using trees as base learners.
   * @param input training dataset
   * @param validationInput validation dataset, ignored if validate is set to false.
   * @param boostingStrategy boosting parameters
   * @param validate whether or not to use the validation dataset.
   * @param seed Random seed.
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def boost(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      validate: Boolean,
      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")

    boostingStrategy.assertValid()

    // Initialize gradient boosting parameters
    /* Maximum number of iterations */
    val numIterations = boostingStrategy.numIterations
    /* Allocate the container for the trained regression-tree models, sized to the iteration count */
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    /* Allocate the weight container for the trained models */
    val baseLearnerWeights = new Array[Double](numIterations)
    /* Loss implementation: L1/L2 for regression, LogLoss for classification (see the loss-function sections above) */
    val loss = boostingStrategy.loss
    /* Learning rate (step size, default 0.1) */
    val learningRate = boostingStrategy.learningRate
    // Prepare strategy for individual trees, which use regression with variance impurity
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    treeStrategy.algo = OldAlgo.Regression
    treeStrategy.impurity = OldVariance
    treeStrategy.assertValid()

    // Cache input: the input RDD is reused across iterations, so persist it to avoid recomputing its upstream DAG
    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
      input.persist(StorageLevel.MEMORY_AND_DISK)
      true
    } else {
      false
    }

    // Prepare periodic checkpointers: persist intermediate data and truncate the preceding DAG
    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)
    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)

    timer.stop("init")

    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")

    // Initialize tree: GBDT starts by training the first regression tree, which is given a weight of 1.0
    timer.start("building tree 0")
    val firstTree = new DecisionTreeRegressor().setSeed(seed)
    val firstTreeModel = firstTree.train(input, treeStrategy)
    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight
    /* Predict on the training data and compute the prediction error with the configured loss (see the loss-function sections above) */
    var predError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
    predErrorCheckpointer.update(predError)
    /* Log the mean prediction error */
    logDebug("error of gbt = " + predError.values.mean())

    // Note: A model of type regression is used since we require raw prediction
    timer.stop("building tree 0")
    /* Predict the validation-set labels and compute the error with the loss function */
    var validatePredError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
    if (validate) validatePredErrorCheckpointer.update(validatePredError)
    /* Mean validation error */
    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
    /* Index of the best model so far */
    var bestM = 1

    var m = 1
    /* Whether to stop iterating early */
    var doneLearning = false
    while (m < numIterations && !doneLearning) {
      // Update data with pseudo-residuals
      /* Use the negative gradient of the loss at the previous prediction as this iteration's label (the gradient is bound to the loss function; see the loss-function sections above) */
      val data = predError.zip(input).map { case ((pred, _), point) =>
        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      }

      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")
      /* Initialize and train a regression decision tree */
      val dt = new DecisionTreeRegressor().setSeed(seed + m)
      val model = dt.train(data, treeStrategy)
      timer.stop(s"building tree $m")
      // Update partial model
      /* Store the trained model in the container */
      baseLearners(m) = model
      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
      //       Technically, the weight should be optimized for the particular loss.
      //       However, the behavior should be reasonable, though not optimal.
      /* The learning rate (step size) becomes the model weight; prediction later uses it:
       * (prediction = previous ensemble prediction + current tree prediction * weight (step size))
       */
      baseLearnerWeights(m) = learningRate
      /* Predict with the updated ensemble (prediction = previous ensemble prediction + current tree prediction * weight (step size)) and compute the error with the configured loss */
      predError = updatePredictionError(
        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
      predErrorCheckpointer.update(predError)
      logDebug("error of gbt = " + predError.values.mean())
      // Early stopping to avoid overfitting; run() currently hard-codes validate = false, so the block below does not execute in this code path
      if (validate) {
        // Stop training early if
        // 1. Reduction in error is less than the validationTol or
        // 2. If the error increases, that is if the model is overfit.
        // We want the model returned corresponding to the best validation error.
        /* Predict the validation-set labels and compute the prediction error */
        validatePredError = updatePredictionError(
          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
        validatePredErrorCheckpointer.update(validatePredError)
        /* Mean validation error */
        val currentValidateError = validatePredError.values.mean()
        /* By default validationTol = 1e-5; stop early if the improvement over the best error falls below this threshold */
        if (bestValidateError - currentValidateError < validationTol * Math.max(
          currentValidateError, 0.01)) {
          doneLearning = true
        } else if (currentValidateError < bestValidateError) {
          /* If the current error beats the best so far, record this model's index as the best (mark it as the best model) */
          bestValidateError = currentValidateError
          bestM = m + 1
        }
      }
      m += 1
    }

    timer.stop("total")

    logInfo("Internal timing for DecisionTree:")
    logInfo(s"$timer")
    /* Delete all persisted intermediate data */
    predErrorCheckpointer.deleteAllCheckpoints()
    validatePredErrorCheckpointer.deleteAllCheckpoints()
    if (persistedInput) input.unpersist()
    /* Return the tree array and the per-model weights (all equal to the step size except the first, which is 1) */
    if (validate) {
      /* If early stopping was enabled, trim the unused slots from the result containers */
      (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
    } else {
      (baseLearners, baseLearnerWeights)
    }
  }
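To summarize what boost returns (a sketch of the accumulation above, not from the source): the ensemble prediction built up across the loop is

$$F_M(x) = f_1(x) + \sum_{m=2}^{M} \eta\, f_m(x)$$

with $f_1$ the initial tree (weight 1.0) and $\eta$ the learning rate; the predict methods in the next section compute exactly this weighted sum via blas.ddot.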

 

Prediction

【Regression implementation】

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    /* Broadcast the model */
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    /* UDF implementing the prediction */
    val predictUDF = udf { (features: Any) =>
      /* Calls the predict function below */
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    /* Append the prediction as a new column on the DataFrame */
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }
 override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
    // Classifies by thresholding sum of weighted tree predictions
    /* Compute each tree's prediction */
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    /* Dot product (ddot) of the per-tree predictions with the tree weights */
    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
  }
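blas.ddot here is simply a weighted sum; an equivalent plain-Scala form (illustrative only):

    // equivalent to blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    val prediction = treePredictions.zip(_treeWeights).map { case (p, w) => p * w }.sum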

【Classification implementation】

 override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    /* Broadcast the model */
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    /* UDF implementing the prediction */
    val predictUDF = udf { (features: Any) =>
      /* Calls the predict function below */
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    /* Append the prediction as a new column on the DataFrame */
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

  override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
    // Classifies by thresholding sum of weighted tree predictions
    /* Compute each tree's prediction */
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    /* Dot product of the per-tree predictions with the tree weights, yielding a real-valued score */
    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    /* Training mapped the labels to {-1, +1}; threshold the score at 0 to map the prediction back to {0, 1} */
    if (prediction > 0.0) 1.0 else 0.0
  }

 

 
