// Multiclass classification with LogisticRegression (使用LogisticRegression處理多分類問題)
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.Row import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.feature.{StringIndexer,VectorIndexer,IndexToString} import org.apache.spark.ml.classification.{LogisticRegression,LogisticRegressionModel} import org.apache.spark.ml.{Pipeline,PipelineModel} import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
/**
 * Multiclass classification on the iris data set with LogisticRegression.
 *
 * Pipeline: StringIndexer -> VectorIndexer -> LogisticRegression -> IndexToString,
 * tuned with a small parameter grid via 3-fold cross-validation, then evaluated
 * on a held-out test split.
 */
object classificationModel {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("spark").getOrCreate()
    val sc = spark.sparkContext
    try {
      // Load the iris data as an RDD of Rows and build a DataFrame.
      // assumes every line has 5 comma-separated numeric fields — TODO confirm input file
      val rowRDD = sc.textFile("hdfs://localhost:9000/dataset/iris.txt")
        .map(_.split(","))
        .map(s => Row(s(0).toDouble, s(1).toDouble, s(2).toDouble, s(3).toDouble, s(4).toDouble))
      val schema = StructType(List(
        StructField("v1", DoubleType, nullable = true),
        StructField("v2", DoubleType, nullable = true),
        StructField("v3", DoubleType, nullable = true),
        StructField("v4", DoubleType, nullable = true),
        StructField("labels", DoubleType, nullable = true)
      ))
      val df = spark.createDataFrame(rowRDD, schema)

      // Assemble the four raw columns into a single feature vector.
      val vectorAssembler = new VectorAssembler()
        .setInputCols(Array("v1", "v2", "v3", "v4"))
        .setOutputCol("features")
      val data = vectorAssembler.transform(df).select("features", "labels")

      // Train/test split; fixed seed keeps the split reproducible.
      val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 1000)
      train.cache()
      test.cache()
      println("Train size = " + train.count() + " Test size = " + test.count())

      // StringIndexer: map label values to frequency-ordered indices.
      // NOTE(review): fitted on the full data set so every label value is seen;
      // standard for indexers, but it does peek at the test split — confirm intent.
      val stringIndexer = new StringIndexer()
        .setInputCol("labels")
        .setOutputCol("indexedLabels")
        .fit(data)
      // VectorIndexer: mark features with <= 10 distinct values as categorical,
      // leave the rest continuous.
      val vectorIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(10)
        .fit(data)
      val logisticRegression = new LogisticRegression()
        .setFeaturesCol("indexedFeatures")
        .setLabelCol("indexedLabels")
        .setMaxIter(20)
      // Map predicted indices back to the original label values.
      val indexToString = new IndexToString()
        .setLabels(stringIndexer.labels)
        .setInputCol("prediction")
        .setOutputCol("forecastLabels")

      // Assemble the full pipeline.
      val pipeline = new Pipeline()
        .setStages(Array(stringIndexer, vectorIndexer, logisticRegression, indexToString))

      // Hyper-parameter grid: 2 x 3 = 6 candidate models.
      val paramGrid = new ParamGridBuilder()
        .addGrid(logisticRegression.elasticNetParam, Array(0.2, 0.8))
        .addGrid(logisticRegression.regParam, Array(0.1, 0.5, 0.8))
        .build()

      // 3-fold cross-validation over the training split (model selection uses the
      // evaluator's default metric).
      val crossValidator = new CrossValidator()
        .setEstimator(pipeline)
        .setEstimatorParamMaps(paramGrid)
        .setEvaluator(new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabels")
          .setPredictionCol("prediction"))
        .setNumFolds(3)
      val cvModel = crossValidator.fit(train)

      /* Prediction output columns:
       * - features / indexedFeatures:  raw vs VectorIndexer-processed feature vectors
       * - labels / indexedLabels:      raw vs StringIndexer-processed labels
       * - probability:                 per-class predicted probabilities
       * - rawPrediction:               softmax scores
       * - prediction / forecastLabels: predicted index vs original label value
       */
      val cvTrainResults = cvModel.transform(train)
      val cvTestResults = cvModel.transform(test)
      cvTestResults.show()

      // FIX: the evaluator's default metric is "f1", but the message below claims
      // accuracy — request the accuracy metric explicitly so output matches.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabels")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val train_acc = evaluator.evaluate(cvTrainResults)
      val test_acc = evaluator.evaluate(cvTestResults)
      println("The accuracy of train set: " + train_acc + " The accuracy of test set: " + test_acc)

      // Inspect the winning model's parameters
      // (previously observed best: elasticNetParam = 0.2, regParam = 0.1).
      val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel]
      // FIX: renamed misleading `DtModel` (suggested a decision tree) to `lrModel`.
      val lrModel = bestModel.stages(2).asInstanceOf[LogisticRegressionModel]
      println("Model params: " + lrModel.explainParams())
    } finally {
      // FIX: the original never stopped the SparkSession; release cluster resources.
      spark.stop()
    }
  }
}