Spark ML

使用LogisticRegression處理多分類問題

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.{StringIndexer,VectorIndexer,IndexToString}
import org.apache.spark.ml.classification.{LogisticRegression,LogisticRegressionModel}
import org.apache.spark.ml.{Pipeline,PipelineModel}
import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
object classificationModel {

  /** Input location used when no path is given on the command line. */
  private val DefaultIrisPath: String = "hdfs://localhost:9000/dataset/iris.txt"

  /**
   * Trains a logistic-regression classifier on the iris data.
   *
   * Tunes `elasticNetParam` and `regParam` with 3-fold cross-validation, then
   * prints train/test accuracy and the parameters of the best model.
   *
   * @param args optional; `args(0)` overrides the input path
   *             (default `hdfs://localhost:9000/dataset/iris.txt`).
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("spark").getOrCreate()
    // Always release the Spark context, even when loading or training throws.
    try run(spark, args.headOption.getOrElse(DefaultIrisPath))
    finally spark.stop()
  }

  /**
   * Loads the data, runs the cross-validated pipeline, and reports results.
   *
   * Each non-blank input line is expected to hold five comma-separated numeric
   * fields: four features followed by a numeric class label. A non-numeric
   * field (e.g. a textual species name) makes `toDouble` throw
   * NumberFormatException.
   *
   * @param spark active session
   * @param path  text file to read (any URI accepted by `SparkContext.textFile`)
   */
  private def run(spark: SparkSession, path: String): Unit = {
    // Load the text as an RDD and build a DataFrame with an explicit schema.
    // Blank lines (e.g. a trailing empty line) are skipped; otherwise the
    // split/toDouble below would throw on them.
    val rowRDD = spark.sparkContext.textFile(path)
      .filter(_.trim.nonEmpty)
      .map(_.split(","))
      .map(f => Row(f(0).toDouble, f(1).toDouble, f(2).toDouble, f(3).toDouble, f(4).toDouble))
    val schema = StructType(List(
      StructField("v1", DoubleType, nullable = true), StructField("v2", DoubleType, nullable = true),
      StructField("v3", DoubleType, nullable = true), StructField("v4", DoubleType, nullable = true),
      StructField("labels", DoubleType, nullable = true)
    ))
    val df = spark.createDataFrame(rowRDD, schema)

    // Combine the four feature columns into a single vector column.
    val assembler = new VectorAssembler()
      .setInputCols(Array("v1", "v2", "v3", "v4"))
      .setOutputCol("features")
    val data = assembler.transform(df).select("features", "labels")

    // 80/20 train/test split; the fixed seed makes the split reproducible.
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 1000)
    train.cache()
    test.cache()
    println("Train size = " + train.count() + " Test size = " + test.count())

    // StringIndexer: maps label values to indices ordered by descending frequency.
    // The numeric labels are cast to strings before indexing.
    // It is fitted on the full dataset so every class gets an index. IndexToString
    // below reuses these labels. Because this is an already-fitted model, it is
    // not refit inside each CV fold.
    val labelIndexer = new StringIndexer()
      .setInputCol("labels")
      .setOutputCol("indexedLabels")
      .fit(data)

    // VectorIndexer: a feature with at most 10 distinct values is treated as
    // categorical and re-indexed; other features stay continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(10)
      .fit(data)

    val lr = new LogisticRegression()
      .setFeaturesCol("indexedFeatures")
      .setLabelCol("indexedLabels")
      .setMaxIter(20)

    // Convert predicted indices back to the original label values.
    val labelConverter = new IndexToString()
      .setLabels(labelIndexer.labels)
      .setInputCol("prediction")
      .setOutputCol("forecastLabels")

    val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))

    // 2 x 3 = 6 candidate settings, each scored by 3-fold cross-validation.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.elasticNetParam, Array(0.2, 0.8))
      .addGrid(lr.regParam, Array(0.1, 0.5, 0.8))
      .build()

    // Model selection uses the evaluator's default metric ("f1").
    val crossValidator = new CrossValidator()
      .setEstimator(pipeline)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabels")
        .setPredictionCol("prediction"))
      .setNumFolds(3)
    val cvModel = crossValidator.fit(train)

    /*
    Prediction output columns:
      - features:        assembled feature vector before indexing
      - indexedFeatures: feature vector after VectorIndexer
      - labels:          original (numeric) label
      - indexedLabels:   label index produced by StringIndexer
      - rawPrediction:   raw per-class margins (scores before softmax)
      - probability:     per-class probabilities (softmax of rawPrediction)
      - prediction:      predicted label index
      - forecastLabels:  predicted label mapped back via IndexToString
     */
    val cvTrainResults = cvModel.transform(train)
    val cvTestResults = cvModel.transform(test)
    cvTestResults.show()

    // The evaluator's default metric is "f1", so "accuracy" must be set
    // explicitly for the printed numbers to match their label.
    val accuracyEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabels")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val trainAcc = accuracyEvaluator.evaluate(cvTrainResults)
    val testAcc = accuracyEvaluator.evaluate(cvTestResults)
    println("The accuracy of train set: " + trainAcc + " The accuracy of test set: " + testAcc)

    // Inspect the winning parameters. The chosen values depend on the data and
    // split; a previous run of this example reported elasticNetParam = 0.2,
    // regParam = 0.1.
    // A CrossValidator over a Pipeline yields a PipelineModel; stage 2 is the
    // fitted logistic-regression model.
    val bestPipeline = cvModel.bestModel.asInstanceOf[PipelineModel]
    val lrModel = bestPipeline.stages(2).asInstanceOf[LogisticRegressionModel]
    println("Best elasticNetParam = " + lrModel.getElasticNetParam + ", regParam = " + lrModel.getRegParam)
    println("Model params: " + lrModel.explainParams())
  }
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章