Spark rewrite: telco customer churn prediction

Reference: https://github.com/baopuzi/Telco_Customer_Churn/blob/master/tele_customer_churn_analysis.ipynb
Background: https://zhuanlan.zhihu.com/p/68397317

package LittleTask

/**
 * @CreateUser: eshter
 * @CreateDate: 2019/12/4
 * @UpdateUser:
 */

import org.apache.spark.ml.classification.{DecisionTreeClassifier, LogisticRegression, RandomForestClassifier}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, mean}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import utils.session.IgnoreErrorAndINFO
import org.apache.spark.ml.feature._
import utils.metrics.Metrics
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

object telco03 {
      //Logger.getLogger("org").setLevel(Level.ERROR)
      new IgnoreErrorAndINFO().ignoreErrorAndInfo()

      // Count, per column, how many rows equal the given placeholder value (e.g. ' ').
      def missValueDistributeValue(df: DataFrame, value: Char, spark: SparkSession) = {
            import spark.implicits._
            val columnNames = df.columns
            val missDF = columnNames.map(co => {
                  co -> df.select(co).filter(s"`$co` = '$value'").count().toString
            }).toSeq.toDF("column", "missing count")
            missDF
      }

      // Count, per column, how many rows are null.
      def missValueDistribute(df: DataFrame, spark: SparkSession) = {
            import spark.implicits._
            val columnNames = df.columns
            val missDF = columnNames.map(co => {
                  co -> df.select(co).filter(s"`$co` is null").count().toString
            }).toSeq.toDF("column", "missing count")
            missDF
      }
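      // Usage sketch (illustrative): count blank placeholders before casting,
      // then nulls after the cast/fill:
      //   missValueDistributeValue(df, ' ', spark).show()
      //   missValueDistribute(df, spark).show()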
      // Assemble the numeric columns into a single vector, then standardize it.
      def feaStandScaleTransform(df: DataFrame, scaleInputCol: Array[String], scaleOutputCol: String, scaleOutputColFinal: String) = {
            val vectorScale = new VectorAssembler()
                .setInputCols(scaleInputCol)
                .setOutputCol(scaleOutputCol)
            val scale = new StandardScaler()
                .setInputCol(scaleOutputCol)
                .setOutputCol(scaleOutputColFinal)
            val pipeline = new Pipeline()
                .setStages(Array(vectorScale, scale))
            val model = pipeline.fit(df)
            model.transform(df)
      }
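      // Usage sketch (column names are illustrative): assemble the numeric
      // columns into "feaScale", then standardize into "sfea":
      //   val scaled = feaStandScaleTransform(df, Array("tenure", "MonthlyCharges"), "feaScale", "sfea")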
      // String-index every categorical column, then one-hot encode all of them
      // except the label ("ChurnToInt" stays an index so it can serve as the label).
      def feaOneHotTransform(df: DataFrame, strCol: Array[String], strIndex: Array[String], onehotCol: Array[String]) = {
            val standardIndex = strCol.map { line =>
                  new StringIndexer().setInputCol(line).setOutputCol(line + "ToInt")
            }
            val onehotEncoder = new OneHotEncoderEstimator()
                .setInputCols(strIndex.filterNot(_ == "ChurnToInt"))
                .setOutputCols(onehotCol)
            val pipelineOnehotEncoder = new Pipeline()
                .setStages(standardIndex ++ Array(onehotEncoder))
            val model = pipelineOnehotEncoder.fit(df)
            val dt1 = model.transform(df)
            dt1.show(10)
            dt1.columns.foreach(println)
            dt1
      }
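      // Naming convention (illustrative example): a categorical column "gender"
      // is indexed to "genderToInt" by StringIndexer, then one-hot encoded into
      // the sparse vector column "genderToOneHotTOInt".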
      // Train logistic regression, random forest, and decision tree classifiers
      // with cross-validated grid search, then report metrics on both the
      // training and validation sets.
      def modelTrain(dfTrain: DataFrame, valid: DataFrame, feaCol: String, labelCol: String) = {

            val lr = new LogisticRegression()
                .setFeaturesCol(feaCol)
                .setLabelCol(labelCol)
                .setProbabilityCol("prob")

            val rf = new RandomForestClassifier()
                .setFeaturesCol(feaCol)
                .setLabelCol(labelCol)
                .setProbabilityCol("prob")
            val dtree = new DecisionTreeClassifier()
                .setFeaturesCol(feaCol)
                .setLabelCol(labelCol)
                .setProbabilityCol("prob")

            val lrParaGrid = new ParamGridBuilder()
                .addGrid(lr.maxIter,Array(50,100,150))
                .build()
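            // The grid could also tune regularization, e.g.
            //   .addGrid(lr.regParam, Array(0.01, 0.1))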
            val rfGridParm =new ParamGridBuilder()
                .addGrid(rf.maxBins,Array(4,5,10))
                .addGrid(rf.maxDepth,Array(2,3,4))
                .build()
            val dtreeGridParm = new ParamGridBuilder()
                .addGrid(dtree.maxBins, Array(4, 5, 10))
                .addGrid(dtree.maxDepth, Array(2, 3, 4))
                .build()
            val models = Array(lr, rf,dtree)
            val paramGrids = Array(lrParaGrid, rfGridParm, dtreeGridParm)
            for (i <- models.indices)
                  {
                        val evaluator = new BinaryClassificationEvaluator()
                            .setLabelCol(labelCol)
                            .setRawPredictionCol("prediction") // set this to "prob" to evaluate on probabilities and apply a custom threshold
                            .setMetricName("areaUnderROC")
                        val cv = new CrossValidator()
                            .setEstimator(models(i))
                            .setEvaluator(evaluator)
                            .setEstimatorParamMaps(paramGrids(i))
                            .setNumFolds(2)
                            .setParallelism(2) // evaluate up to 2 parameter settings in parallel
                        val cvModel = cv.fit(dfTrain)
                        val finalDf = cvModel.transform(dfTrain)
                        val metr = new Metrics(finalDf, labelCol, "prediction")
                        println("Training set results:")
                        metr.metricFunc()
                        val validDF = cvModel.transform(valid)
                        val metr1 = new Metrics(validDF, labelCol, "prediction")
                        println("Validation set results:")
                        metr1.metricFunc()
                  }
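                  // Sketch for inspecting a tuned CrossValidatorModel (standard
                  // spark.ml API, not used above):
                  //   cvModel.bestModel   // the winning model, re-fit on the full training set
                  //   cvModel.avgMetrics  // mean AUC for each ParamMap in the grid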





      }



      def main(args: Array[String]): Unit = {
            val spark = SparkSession.builder()
                .master("local[2]")
                .appName("winter")
                .getOrCreate()
            import spark.implicits._
            val dt = spark
                .read
                .format("csv")
                .option("header", true)
                .option("inferSchema", true)
                //.option("multiLine",true)
                .option("delimiter", ",")
                .load("/Users/eshter/BikeDmtSparkJob/bikedmtSpark/src/main/scala/bikedmtsparkjob/yufang/LittleTask/WA_Fn-UseC_-Telco-Customer-Churn.csv")
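            // (The file is the IBM sample "Telco Customer Churn" dataset, also available on Kaggle.)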
            println("數據集的size="+dt.count()+"\n特徵的個數="+dt.columns.length)
            println(dt.summary().show())
            println(dt.show(10))
            import spark.implicits._
            val colsName = dt.columns
            println(colsName.foreach(println(_)))
            // Cast columns to their proper data types
            val index_col =Array("customerID")
            val str_col =Array(
                  "gender"
                  ,"Partner"
                  ,"Dependents"
                  ,"PhoneService"
                  ,"MultipleLines"
                  ,"InternetService"
                  ,"OnlineSecurity"
                  ,"OnlineBackup"
                  ,"DeviceProtection"
                  ,"TechSupport"
                  ,"StreamingTV"
                  ,"StreamingMovies"
                  ,"Contract"
                  ,"PaperlessBilling"
                  ,"PaymentMethod"
                  ,"Churn"
            )


            val int_col = Array("tenure", "SeniorCitizen")
            val label_col = Array("Churn")
            val dou_col = Array("MonthlyCharges", "TotalCharges")
            val index_cols = index_col.map(c=>col(c).cast(StringType))
            val int_cols = int_col.map(c => col(c).cast(IntegerType))
            val str_cols = str_col.map(c => col(c).cast(StringType))
            val dou_cols =dou_col.map(c=> col(c).cast(DoubleType))

            val dt_tmp = dt.select(index_cols ++ int_cols ++ dou_cols ++ str_cols: _*)
            println("Summary of the numeric columns:")
            dt_tmp.select(int_cols ++ dou_cols: _*).summary().show()

            // Inspect the distribution of missing values (blanks use the ' ' placeholder)
            val missvalue = ' '
            val missD = missValueDistributeValue(dt_tmp, missvalue, spark)
            missD.show(21)

            // Fill missing values: the blanks in TotalCharges became null after the
            // cast to Double; fill them with the mean of MonthlyCharges
            val mean_month_value = dt_tmp.select(mean(col("MonthlyCharges"))).first()(0)
            println("mean_month_value = " + mean_month_value)
            val dt_tmpp = dt_tmp.na.fill(Map("TotalCharges" -> mean_month_value))
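            // Alternative sketch (assumes Spark >= 2.2): spark.ml's Imputer can fill
            // TotalCharges with its own column mean instead of the MonthlyCharges mean:
            //   val imputed = new Imputer()
            //       .setInputCols(Array("TotalCharges"))
            //       .setOutputCols(Array("TotalChargesFilled"))
            //       .setStrategy("mean")
            //       .fit(dt_tmp).transform(dt_tmp)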

            // Check the missing values again after filling
            val missDF = missValueDistribute(dt_tmpp, spark)
            missDF.show(21)


            // StandardScaler only accepts a Vector column, so assemble the numeric
            // columns into a vector first and then standardize it
            val scaleInputCol = int_col ++ dou_col
            val scaleOutputCol = "feaScale"
            val scaleOutputColFinal = "sfea"
            val scaleDF = feaStandScaleTransform(dt_tmpp, scaleInputCol = scaleInputCol, scaleOutputCol = scaleOutputCol, scaleOutputColFinal = scaleOutputColFinal)
            scaleDF.columns.foreach(println)

            val strIndex = str_col.map(c => c + "ToInt")
            val onehotCol = str_col.filterNot(_ == "Churn").map(c => c + "ToOneHotTOInt")


            val dtTrainB = feaOneHotTransform(scaleDF, str_col, strIndex, onehotCol)
            // Finally assemble the one-hot columns and the scaled numeric vector
            // into one complete feature vector
            val fVector = new VectorAssembler()
                .setInputCols(onehotCol ++ Array(scaleOutputColFinal))
                .setOutputCol("features")
            val pipelineF = new Pipeline()
                .setStages(Array(fVector))
            val modelF = pipelineF.fit(dtTrainB)
            val dtTrain = modelF.transform(dtTrainB)
            dtTrain.show(2)
            val data = dtTrain.randomSplit(Array(0.8, 0.2), seed = 2)
            modelTrain(data(0),data(1),"features",labelCol="ChurnToInt")
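            // To persist a tuned model, modelTrain could return the CrossValidatorModel
            // and use the standard ML persistence API (path is hypothetical):
            //   cvModel.write.overwrite().save("models/telco_churn_cv")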








      }

}

