算法小白的第一測嘗試---ID3(Decision Tree)

package DecesionTree
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.StringIndexer
import java.math._
import scala.collection.mutable.ArrayBuffer
/**
 * 基於ID3算法選擇最優特徵---統計學習方法
 */
object ID3 {
  def main(args: Array[String]): Unit = {
    val conf=new SparkConf().setMaster("local").setAppName("ML")
    val sc=new SparkContext(conf)
    val sqlcontext=new SQLContext(sc)
    import sqlcontext.implicits._
    val sampleData=Array(Array("1","青年","否","否","一般","否"),Array("2","青年","否","否","好","否"),Array("3","青年","是","否","好","是"),
                   Array("4","青年","是","是","一般","是"),Array("5","青年","否","否","一般","否"),Array("6","中年","否","否","一般","否"),
                   Array("7","中年","否","否","好","否"),Array("8","中年","是","是","好","是"),Array("9","中年","否","是","非常好","是"),
                   Array("10","中年","否","是","非常好","是"),Array("11","老年","否","是","非常好","是"),Array("12","老年","否","是","好","是"),
                   Array("13","老年","是","否","好","是"),Array("14","老年","是","否","非常好","是"),Array("15","老年","否","否","一般","否"))
    
    val newData=sc.parallelize(sampleData).map { x =>
      val age=x(1)
      val work=x(2)
      val house=x(3)
      val credit=x(4)
      val label=x(5)
      (age,work,house,credit,label)  
    }.persist(StorageLevel.MEMORY_ONLY_SER)
    val DF=newData.toDF("age","isWork","isHouse","credit","label").persist(StorageLevel.MEMORY_ONLY_SER)
    val features=DF.columns
      
    //計算數據集D的熵
    val totalRecord=newData.count()
    val labels=newData.map(x=>(x._5,1)).reduceByKey(_+_).collect()
  
    var Hd=0.0
    for(lab<-labels){
      val labelcount=lab._2.toDouble
      val pi=labelcount/totalRecord
      Hd+= -1.0*((pi)*Math.log(pi)/Math.log(2))
    }
   
   //計算特徵A對數據集的經驗條件熵
   val Hda=ArrayBuffer[Double]()
    for(feature<-features){
      var Hdik=0.0
      if(!"label".equals(feature)){
        //DI表示特徵A對應的信息
        val DI=DF.groupBy(feature).count()
        //lab表示特徵A所有可能得取值
        val Lab=ArrayBuffer[String]()
        val Di=ArrayBuffer[Int]()
        DI.collect().map { Row =>  
          Lab +=Row.getString(Row.fieldIndex(feature))
          Di +=Row.getLong(Row.fieldIndex("count")).toInt
        }    
        //獲取Dik信息
        val Dik=ArrayBuffer[(Int,Int)]()
        for(lab<-Lab){
          var i=0
          val str=s"$feature = " + s"'$lab'"
          println("str:"+str)
          val newDF=DF.where(str).groupBy("label").count()
          val df=newDF.rdd.map { Row => Row.getLong(Row.fieldIndex("count")).toInt}.collect()
          if(newDF.count().toInt ==2)  Dik.append((df(0),df(1))) else Dik.append((df(0),0))
        }
        //計算每個label的條件熵
        for(i<-Di){
          val newDik=Dik.take(1)
          Dik.remove(0,1)
          for(j<-newDik){
            if(j._2 ==0){
              val pi=j._1.toDouble/i
              Hdik += i.toDouble/totalRecord*(-1.0)*(pi)*Math.log(pi)/Math.log(2)
            }else{
              val pi1=j._1.toDouble/i
              val pi2=j._2.toDouble/i
               Hdik += i.toDouble/totalRecord*(-1.0)*(pi1)*Math.log(pi1)/Math.log(2) + i.toDouble/totalRecord*(-1.0)*(pi2)*Math.log(pi2)/Math.log(2) 
            }
          }
        }
        Hda.append(Hdik)
      }
    }
   //Gda表示信息增益,選取信息增益最大值作爲最優特徵。
   val Gda=ArrayBuffer[(String,Double)]()
   for(i<-0 until Hda.length){
     val hda=Hda(i)
     Gda.append((features(i),(Hd-hda)))
   }
   Gda.foreach { x => println(x) } 
  }
}
實驗結果:
(age,0.08300749985576883)
(isWork,0.3236501981515564)
(isHouse,0.419973094021975)
(credit,0.3629895625370855)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章