算法小白的第一次嘗試---ID3實現決策樹

package DecesionTree
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.StringIndexer
import java.math._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.DataFrame
/**
 * 基於ID3算法生成決策樹---統計學習方法
 */
object ID4Tree {
  def main(args: Array[String]): Unit = {    
    val conf=new SparkConf().setMaster("local").setAppName("ML")
    val sc=new SparkContext(conf)
    val sqlcontext=new SQLContext(sc)
    val t1=System.currentTimeMillis()
    //train數據
    val sampleData=Array(Array("1","青年","否","否","一般","否"),Array("2","青年","否","否","好","否"),Array("3","青年","是","否","好","是"),
                   Array("4","青年","是","是","一般","是"),Array("5","青年","否","否","一般","否"),Array("6","中年","否","否","一般","否"),
                   Array("7","中年","否","否","好","否"),Array("8","中年","是","是","好","是"),Array("9","中年","否","是","非常好","是"),
                   Array("10","中年","否","是","非常好","是"),Array("11","老年","否","是","非常好","是"),Array("12","老年","否","是","好","是"),
                   Array("13","老年","是","否","好","是"),Array("14","老年","是","否","非常好","是"),Array("15","老年","否","否","一般","否"))
   import sqlcontext.implicits._
   var DF=sc.parallelize(sampleData).map { x => 
      val age=x(1)
      val work=x(2)
      val house=x(3)
      val credit=x(4)
      val label=x(5)
      (age,work,house,credit,label)  
    }.toDF("age","isWork","isHouse","credit","label")
    
    //決策樹節點,內部節點(String),其中String表示當前內部節點,count表示該內部節點所對應的葉子節點數
    val internalNode=ArrayBuffer[String]()
    //決策樹節點,葉子節點(str1,str2),其中str1表示判定條件,str2表示葉子節點標記
    val leafNode=ArrayBuffer[(String,String)]()
    //每個內部節點所對應的子節點數
    val countArr=ArrayBuffer[Int]()
    var flag=true
    while(flag){
      val totalRecord=DF.rdd.count().toInt
      val labels=DF.select("label").rdd.map(Row=>Row.getString(Row.fieldIndex("label"))).map { x => (x,1)}.reduceByKey(_+_).collect()
      val featurePoint=Getfeature(DF,totalRecord,labels)(0)._1
      //df表示上一個特徵點對應的所有label
      val df=DF.select(featurePoint).distinct().rdd.map { Row => Row.getString(Row.fieldIndex(featurePoint))}.collect()
      var count=0
      //arr表示該內部節點中非葉子節點的子節點
      var arr=""
      for(lb<-df){
        //根據最優特徵,劃分數據集
        val str=s"$featurePoint =" + s"'$lb'" 
        val newDF=DF.where(str).select("label").distinct().rdd.map { Row => Row.getString(Row.fieldIndex("label"))}.collect()
        val D1=newDF.length
        if(D1==1){
          leafNode.append((lb,newDF(0)))
          count +=1
        }else{
          arr=lb
        }
      }
      internalNode.append(featurePoint)
      //判斷決策樹是否訓練完成,若當前內部節點所對應的葉子節點的個數爲2,則表示訓練結束
      if(count==2) flag=false else{
        val str2=s"$featurePoint =" + s"'$arr'"
        var sk=DF.where(str2).toDF().rdd.map { Row => 
          val ID="0"
          val age=Row.getString(Row.fieldIndex("age"))
          val isWork=Row.getString(Row.fieldIndex("isWork"))
          val isHouse=Row.getString(Row.fieldIndex("isHouse"))
          val credit=Row.getString(Row.fieldIndex("credit"))
          val label=Row.getString(Row.fieldIndex("label"))
          Array(ID,age,isWork,isHouse,credit,label)
        }.collect()
        //此處應該刷新一下DF
        DF=sc.parallelize(sk).map { x => 
          val age=x(1)
          val work=x(2)
          val house=x(3)
          val credit=x(4)
          val label=x(5)
          (age,work,house,credit,label)  
        }.toDF("age","isWork","isHouse","credit","label")    
      }
      countArr.append(count)
    }
    
    val t2=System.currentTimeMillis()
    println("決策樹算法總耗時:"+(t2-t1))
    println("所有的內部節點")
    internalNode.foreach { x => println(x) }
    println("所有的葉子節點")
    leafNode.foreach(println(_))
    println("每個內部節點對應的葉子節點個數")
    countArr.foreach { x => println(x) }
  }
   
  /**
   * 根據ID3算法,求取最優特徵
   */
  def Getfeature(DF:DataFrame,totalRecord:Int,labels:Array[(String,Int)]):ArrayBuffer[(String, Double)]={   
      val features=DF.columns
      //計算數據集D的熵
      var Hd=0.0
      for(lab<-labels){
        val labelcount=lab._2.toDouble
        val pi=labelcount/totalRecord
        Hd+= -1.0*((pi)*Math.log(pi)/Math.log(2))
      }
     //計算特徵A對數據集的經驗條件熵
     val Hda=ArrayBuffer[Double]()
       for(feature<-features){
          var Hdik=0.0
          if(!"label".equals(feature)){
            //DI表示特徵A對應的信息
            val DI=DF.groupBy(feature).count()
            //lab表示特徵A所有可能得取值
            val Lab=ArrayBuffer[String]()
            val Di=ArrayBuffer[Int]()
            DI.collect().map { Row =>  
              Lab +=Row.getString(Row.fieldIndex(feature))
              Di +=Row.getLong(Row.fieldIndex("count")).toInt
            }    
            //獲取Dik信息
            val Dik=ArrayBuffer[(Int,Int)]()
            for(lab<-Lab){
              var i=0
              val str=s"$feature = " + s"'$lab'"
              val newDF=DF.where(str).groupBy("label").count().persist(StorageLevel.MEMORY_ONLY_SER)
              val df=newDF.rdd.map { Row => Row.getLong(Row.fieldIndex("count")).toInt}.collect()
              if(newDF.count().toInt ==2)  Dik.append((df(0),df(1))) else Dik.append((df(0),0))
            }
            //計算每個label的條件熵
            for(i<-Di){
              val newDik=Dik.take(1)
              Dik.remove(0,1)
              for(j<-newDik){
                if(j._2 ==0){
                  val pi=j._1.toDouble/i
                  Hdik += i.toDouble/totalRecord*(-1.0)*(pi)*Math.log(pi)/Math.log(2)
                }else{
                  val pi1=j._1.toDouble/i
                  val pi2=j._2.toDouble/i
                   Hdik += i.toDouble/totalRecord*(-1.0)*(pi1)*Math.log(pi1)/Math.log(2) + i.toDouble/totalRecord*(-1.0)*(pi2)*Math.log(pi2)/Math.log(2) 
                }
              }
            }
            Hda.append(Hdik)
          }
      }
     //Gda表示信息增益,選取信息增益最大值作爲最優特徵。
     val Gda=ArrayBuffer[(String,Double)]()
     for(i<-0 until Hda.length){
       val hda=Hda(i)
       Gda.append((features(i),(Hd-hda)))
     }
     Gda.sortBy(x=>x._2).reverse.take(1)
  }
}
-----------------------------
Result:
所有的內部節點
isHouse
isWork
所有的葉子節點
(,)
(,)
(,)
每個內部節點對應的葉子節點個數
1
2
----------------------------
Decision Tree:
if("是".equals(isHouse))
	"是"
else if("是".equals(isWork))
	"是"
else "否"	
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章