Spark ML 特徵工程之 One-Hot Encoding

1.什麼是One-Hot Encoding

One-Hot Encoding 也就是獨熱碼,直觀來說就是有多少個狀態就有多少比特,而且只有一個比特爲1,其他全爲0的一種碼制。在機器學習(Logistic Regression,SVM等)中對於離散型的分類型的數據,需要對其進行數字化比如說性別這一屬性,只能有男性或者女性或者其他這三種值,如何對這三個值進行數字化表達?一種簡單的方式就是男性爲0,女性爲1,其他爲2,這樣做有什麼問題?
使用上面簡單的序列對分類值進行表示後,進行模型訓練時可能會產生一個問題就是特徵的因爲數字值得不同影響模型的訓練效果,在模型訓練的過程中不同的值使得同一特徵在樣本中的權重可能發生變化,假如直接編碼成1000,是不是比編碼成1對模型的的影響更大。爲了解決上述的問題,使訓練過程中不受到因爲分類值表示的問題對模型產生的負面影響,引入獨熱碼對分類型的特徵進行獨熱碼編碼。

2.One-Hot Encoding在Spark中的應用

測試數據地址

2.1 數據集預覽

數據中字段含義如下:
affairs:Double //是否有婚外情
gender:String //性別 
age:Double //年齡 
yearsmarried:Double //婚齡 
children:String //是否有小孩 
religiousness:Double //宗教信仰程度(5分制,1分表示反對,5分表示非常信仰)
education:Double //學歷
occupation:Double //職業(逆向編號的戈登7種分類) 
rating:Double //對婚姻的自我評分(5分制,1表示非常不幸福,5表示非常幸福)

2.2 加載數據集

    val conf = new SparkConf().setMaster("local[4]").setAppName(getClass.getSimpleName).set("spark.testing.memory", "2147480000")
    val sparkContext = new SparkContext(conf)
    val sqlContext = new HiveContext(sparkContext)
    val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
    val logPath = "E:\\spark_workspace\\spark-study\\src\\main\\files\\lr_test03.json"
    import sqlContext.implicits._

    val dataDF = sqlContext.read.json(logPath).select($"affairs", $"gender", $"age", $"yearsmarried", $"children", $"religiousness", $"education", $"occupation", $"rating")
    

2.3 使用OneHotEncoder處理數據集

    /**要進行OneHotEncoder編碼的字段*/
    val categoricalColumns = Array("gender", "children")
    /**採用Pileline方式處理機器學習流程*/
    val stagesArray = new ListBuffer[PipelineStage]()
    for (cate <- categoricalColumns) {
      /**使用StringIndexer 建立類別索引*/
      val indexer = new StringIndexer().setInputCol(cate).setOutputCol(s"${cate}Index")
      /**使用OneHotEncoder將分類變量轉換爲二進制稀疏向量*/
      val encoder = new OneHotEncoder().setInputCol(indexer.getOutputCol).setOutputCol(s"${cate}classVec")
      stagesArray.append(indexer,encoder)
    }

2.4 使用VectorAssembler合併所有特徵爲單個向量

    val numericCols = Array("affairs", "age", "yearsmarried", "religiousness", "education", "occupation", "rating")
    val assemblerInputs = categoricalColumns.map(_ + "classVec") ++ numericCols
    /**使用VectorAssembler將所有特徵轉換爲一個向量*/
    val assembler = new VectorAssembler().setInputCols(assemblerInputs).setOutputCol("features")
    stagesArray.append(assembler)

2.5 以Pipeline的形式運行各個PipelineStage

    val pipeline = new Pipeline()
    pipeline.setStages(stagesArray.toArray)
    /**fit() 根據需要計算特徵統計信息*/
    val pipelineModel = pipeline.fit(dataDF)
    /**transform() 真實轉換特徵*/
    val dataset = pipelineModel.transform(dataDF)
    dataset.show(false)

One-Hot Encoding 之後的數據集結果如下圖:

+-------+------+----+------------+--------+-------------+---------+----------+------+-----------+--------------+-------------+----------------+----------------------------------------+
|affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating|genderIndex|genderclassVec|childrenIndex|childrenclassVec|features                                |
+-------+------+----+------------+--------+-------------+---------+----------+------+-----------+--------------+-------------+----------------+----------------------------------------+
|0.0    |male  |37.0|10.0        |no      |3.0          |18.0     |7.0       |4.0   |1.0        |(1,[],[])     |1.0          |(1,[],[])       |[0.0,0.0,0.0,37.0,10.0,3.0,18.0,7.0,4.0]|
|0.0    |female|27.0|4.0         |no      |4.0          |14.0     |6.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,27.0,4.0,4.0,14.0,6.0,4.0] |
|0.0    |female|32.0|15.0        |yes     |1.0          |12.0     |1.0       |4.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,32.0,15.0,1.0,12.0,1.0,4.0]|
|0.0    |male  |57.0|15.0        |yes     |5.0          |18.0     |6.0       |5.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,57.0,15.0,5.0,18.0,6.0,5.0]|
|0.0    |male  |22.0|0.75        |no      |2.0          |17.0     |6.0       |3.0   |1.0        |(1,[],[])     |1.0          |(1,[],[])       |[0.0,0.0,0.0,22.0,0.75,2.0,17.0,6.0,3.0]|
|0.0    |female|32.0|1.5         |no      |2.0          |17.0     |5.0       |5.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,32.0,1.5,2.0,17.0,5.0,5.0] |
|0.0    |female|22.0|0.75        |no      |2.0          |12.0     |1.0       |3.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,0.75,2.0,12.0,1.0,3.0]|
|0.0    |male  |57.0|15.0        |yes     |2.0          |14.0     |4.0       |4.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,57.0,15.0,2.0,14.0,4.0,4.0]|
|0.0    |female|32.0|15.0        |yes     |4.0          |16.0     |1.0       |2.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,32.0,15.0,4.0,16.0,1.0,2.0]|
|0.0    |male  |22.0|1.5         |no      |4.0          |14.0     |4.0       |5.0   |1.0        |(1,[],[])     |1.0          |(1,[],[])       |[0.0,0.0,0.0,22.0,1.5,4.0,14.0,4.0,5.0] |
|0.0    |male  |37.0|15.0        |yes     |2.0          |20.0     |7.0       |2.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,37.0,15.0,2.0,20.0,7.0,2.0]|
|0.0    |male  |27.0|4.0         |yes     |4.0          |18.0     |6.0       |4.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,27.0,4.0,4.0,18.0,6.0,4.0] |
|0.0    |male  |47.0|15.0        |yes     |5.0          |17.0     |6.0       |4.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,47.0,15.0,5.0,17.0,6.0,4.0]|
|0.0    |female|22.0|1.5         |no      |2.0          |17.0     |5.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,1.5,2.0,17.0,5.0,4.0] |
|0.0    |female|27.0|4.0         |no      |4.0          |14.0     |5.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,27.0,4.0,4.0,14.0,5.0,4.0] |
|0.0    |female|37.0|15.0        |yes     |1.0          |17.0     |5.0       |5.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,37.0,15.0,1.0,17.0,5.0,5.0]|
|0.0    |female|37.0|15.0        |yes     |2.0          |18.0     |4.0       |3.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,37.0,15.0,2.0,18.0,4.0,3.0]|
|0.0    |female|22.0|0.75        |no      |3.0          |16.0     |5.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,0.75,3.0,16.0,5.0,4.0]|
|0.0    |female|22.0|1.5         |no      |2.0          |16.0     |5.0       |5.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,1.5,2.0,16.0,5.0,5.0] |
|0.0    |female|27.0|10.0        |yes     |2.0          |14.0     |1.0       |5.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,27.0,10.0,2.0,14.0,1.0,5.0]|
+-------+------+----+------------+--------+-------------+---------+----------+------+-----------+--------------+-------------+----------------+----------------------------------------+

2.6 訓練和評估模型

    /**隨機分割測試集和訓練集數據,指定seed可以固定數據分配*/
    val Array(trainingDF, testDF) = dataset.randomSplit(Array(0.6, 0.4), seed = 12345)
    println(s"trainingDF size=${trainingDF.count()},testDF size=${testDF.count()}")
    val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)
    val predictions = lrModel.transform(testDF).select($"affairs".as("label"), $"features", $"rawPrediction", $"probability", $"prediction")
    predictions.show(false)
    /**使用BinaryClassificationEvaluator來評價我們的模型。在metricName參數中設置度量。*/
    val evaluator = new BinaryClassificationEvaluator()
    evaluator.setMetricName("areaUnderROC")
    val auc= evaluator.evaluate(predictions)
    println(s"areaUnderROC=$auc")

使用model 預測後的數據如下圖所示:

+-----+-----------------------------------------+----------------------------------------+-------------------------------------------+----------+
|label|features                                 |rawPrediction                           |probability                                |prediction|
+-----+-----------------------------------------+----------------------------------------+-------------------------------------------+----------+
|0.0  |[1.0,0.0,0.0,22.0,0.125,4.0,14.0,4.0,5.0]|[24.24907721362884,-24.24907721362884]  |[0.999999999970572,2.942792055040055E-11]  |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.417,1.0,17.0,6.0,4.0]|[21.290119589459323,-21.290119589459323]|[0.9999999994326925,5.673075233382041E-10] |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.417,5.0,14.0,1.0,4.0]|[24.17979109657276,-24.17979109657276]  |[0.9999999999684608,3.1539162239002745E-11]|0.0       |
|0.0  |[1.0,1.0,0.0,22.0,0.417,3.0,14.0,3.0,5.0]|[22.67775610810491,-22.67775610810491]  |[0.9999999998583633,1.4163665456478983E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,2.0,12.0,1.0,3.0] |[18.511403509878832,-18.511403509878832]|[0.9999999908672915,9.13270857267764E-9]   |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,4.0,16.0,1.0,5.0] |[25.35929557565844,-25.35929557565844]  |[0.999999999990304,9.69611742832185E-12]   |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,5.0,14.0,3.0,5.0] |[25.260012900022847,-25.260012900022847]|[0.9999999999892919,1.070818300382037E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,5.0,18.0,1.0,5.0] |[27.56176640273893,-27.56176640273893]  |[0.9999999999989282,1.0717091528412073E-12]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,2.0,14.0,4.0,5.0]  |[21.806773356131036,-21.806773356131036]|[0.9999999996615936,3.3840647423836113E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,2.0,16.0,5.0,5.0]  |[22.87962909201085,-22.87962909201085]  |[0.9999999998842548,1.1574529263994485E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,2.0,16.0,5.0,5.0]  |[22.87962909201085,-22.87962909201085]  |[0.9999999998842548,1.1574529263994485E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,4.0,16.0,5.0,3.0]  |[22.617887847315348,-22.617887847315348]|[0.9999999998496247,1.5037516453560028E-10]|0.0       |
|0.0  |[1.0,1.0,0.0,22.0,1.5,3.0,16.0,5.0,5.0]  |[23.505953663596607,-23.505953663596607]|[0.9999999999381279,6.187198251529256E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,4.0,4.0,17.0,5.0,5.0]  |[25.142053761516753,-25.142053761516753]|[0.9999999999879512,1.2048827525325212E-11]|0.0       |
|0.0  |[1.0,0.0,0.0,27.0,1.5,2.0,16.0,6.0,5.0]  |[23.342953469838886,-23.342953469838886]|[0.9999999999271745,7.282560759398736E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,27.0,1.5,2.0,18.0,6.0,5.0]  |[24.454819713457812,-24.454819713457812]|[0.9999999999760445,2.3955582882827004E-11]|0.0       |
|0.0  |[1.0,0.0,0.0,27.0,1.5,3.0,18.0,5.0,2.0]  |[21.920009187230548,-21.920009187230548]|[0.9999999996978233,3.021766947986581E-10] |0.0       |
|0.0  |[1.0,0.0,0.0,27.0,4.0,2.0,18.0,5.0,5.0]  |[24.01911260197023,-24.01911260197023]  |[0.9999999999629634,3.703667040712842E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,27.0,4.0,3.0,16.0,5.0,4.0]  |[22.776375736003562,-22.776375736003562]|[0.9999999998716649,1.2833517289922962E-10]|0.0       |
|0.0  |[1.0,1.0,0.0,27.0,4.0,2.0,18.0,6.0,1.0]  |[18.629921259118063,-18.629921259118063]|[0.999999991887999,8.112000996701378E-9]   |0.0       |
+-----+-----------------------------------------+----------------------------------------+-------------------------------------------+----------+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章