// Spark rewrite: cardiovascular disease prediction.
//
// Python version: https://www.kesci.com/home/project/5da974e9c83fb400420f77d3

package dataclear

/**
 * @CreateUser: eshter
 * @CreateDate: 2019/10/23
 * @UpdateUser:
 */

import utils.session.IgnoreErrorAndINFO
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.ml.classification.{LogisticRegression}
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler, _}
import utils.metrics.Metrics
import org.apache.spark.ml.Pipeline
object cardioTrainLr {
  /*
   Notes:
   1. Label column = "cardio".
   2. StandardScaler only accepts vector input (org.apache.spark.ml.linalg.Vector),
      so the continuous columns are first packed with a VectorAssembler.
   3. Continuous columns: age, height, weight, ap_hi, ap_lo.
   4. Categorical columns: gender, cholesterol, gluc, smoke, alco.
   */
  // Silence Spark's INFO/ERROR console noise (project-local helper).
  new IgnoreErrorAndINFO().ignoreErrorAndInfo()

  /**
   * Randomly splits a DataFrame into a train and a validation part.
   *
   * @param df        input data
   * @param splitRate fraction of rows assigned to the first (training) split
   * @return List(train, validation); fixed seed keeps the split reproducible
   */
  def splitData(df: DataFrame, splitRate: Double): List[DataFrame] = {
    val parts = df.randomSplit(Array(splitRate, 1 - splitRate), seed = 2)
    List(parts(0), parts(1))
  }

  /**
   * Feature engineering: standard-scales the continuous columns and
   * string-indexes the categorical columns, then assembles everything
   * into a single "features" vector column.
   *
   * All transformers are fit on the TRAINING data only and then applied
   * to both splits, avoiding leakage from the validation set.
   *
   * @param dfTrain     training split
   * @param dfValid     validation split
   * @param featureCols currently unused; kept for interface compatibility
   *                    (column lists are hard-coded below)
   * @return List(transformed train, transformed validation)
   */
  def featureHandleTest(dfTrain: DataFrame, dfValid: DataFrame, featureCols: Array[String]): List[DataFrame] = {
    // Continuous columns to be standard-scaled.
    val scaleCols = Array(
      "age"
      , "height"
      , "weight"
      , "ap_hi"
      , "ap_lo"
    )
    // Categorical columns to be string-indexed.
    val categoricalCols = Array(
      "gender"
      , "cholesterol"
      , "gluc"
      , "smoke"
      , "alco"
    )
    val indexedCols = categoricalCols.map(_ + "ToInt")

    // One StringIndexer per categorical column: "<col>" -> "<col>ToInt".
    val indexers = categoricalCols.map { col =>
      new StringIndexer().setInputCol(col).setOutputCol(col + "ToInt")
    }

    // Stage 1: pack continuous columns into a vector, then standard-scale it.
    val continuousAssembler = new VectorAssembler()
      .setInputCols(scaleCols)
      .setOutputCol("feaScale")
    val scaler = new StandardScaler().setInputCol("feaScale").setOutputCol("sfea")
    val scalePipeline = new Pipeline().setStages(Array(continuousAssembler, scaler))

    val scaleModel = scalePipeline.fit(dfTrain)
    val scaledTrain = scaleModel.transform(dfTrain)
    val scaledValid = scaleModel.transform(dfValid)

    // Stage 2: index categoricals and merge them with the scaled vector
    // into the final "features" column.
    val finalAssembler = new VectorAssembler()
      .setInputCols(indexedCols ++ Array("sfea"))
      .setOutputCol("features")
    val finalPipeline = new Pipeline()
      .setStages(indexers ++ Array(finalAssembler))

    val finalModel = finalPipeline.fit(scaledTrain)
    List(finalModel.transform(scaledTrain), finalModel.transform(scaledValid))
  }

  /**
   * Trains a logistic-regression classifier and prints metrics for both
   * the training and the validation split.
   *
   * @param dfTrain    training data containing `featureCol` and `label`
   * @param dfValid    validation data with the same columns
   * @param featureCol name of the assembled feature-vector column
   * @param label      name of the label column
   */
  def modelTrainLr(dfTrain: DataFrame, dfValid: DataFrame, featureCol: String, label: String): Unit = {
    val lr = new LogisticRegression()
      .setLabelCol(label)
      .setFeaturesCol(featureCol)
      .setMaxIter(50)

    val lrModel = lr.fit(dfTrain)

    // Training-set metrics.
    val predTrain = lrModel.transform(dfTrain)
    new Metrics(predTrain, label, "prediction").metricFunc()

    // BUG FIX: the original transformed dfTrain again here, so the
    // "validation" metrics were just the training metrics repeated.
    val predValid = lrModel.transform(dfValid)
    new Metrics(predValid, label, "prediction").metricFunc()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("cardio_train")
      .getOrCreate()

    // Data path can be passed as the first CLI argument; falls back to the
    // original hard-coded location for backward compatibility.
    val dataPath = args.headOption.getOrElse("/Users/eshter/Desktop/cc_data/cardio_train.csv")

    var srcTrain = spark
      .read
      .format("csv")
      .option("header", true)
      .option("inferSchema", true)
      .option("delimiter", ";")
      .load(dataPath)
    // show() returns Unit, so call it directly instead of println(show(...)).
    srcTrain.show(2)

    val labelCol = "cardio"

    // Drop the id column — it carries no predictive information.
    srcTrain = srcTrain.drop("id")
    // Print summary statistics of every column.
    srcTrain.summary().show()

    val splitRate = 0.85
    val splits = splitData(srcTrain, splitRate)
    val dfTrain = splits(0)
    val dfValid = splits(1)

    println(dfTrain.stat.corr("gender", labelCol, "pearson").toString)
    // Keep columns whose absolute Pearson correlation with the label
    // exceeds 0.1, excluding the label itself (was hard-coded "cardio").
    val featCols = dfTrain.columns
      .filter(dfTrain.stat.corr(_, labelCol, "pearson").abs > 0.1)
      .filter(_ != labelCol)

    val scaled = featureHandleTest(dfTrain, dfValid, featCols)
    modelTrainLr(scaled(0), scaled(1), "features", labelCol)

    // Release Spark resources before exiting.
    spark.stop()
  }

}

// (Blog footer, kept as a comment so the file compiles:
//  48 original articles · 83 likes · 70k+ views.)