import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import scala.collection.mutable.ArrayBuffer
/**
* @author XiaoTangBao
* @date 2019/3/10 16:00
* @version 1.0
* 基於統計學習方法--李航 例8.2 提升樹(迴歸)
*/
/**
 * Boosting tree for regression, built from decision stumps,
 * following Example 8.2 of "Statistical Learning Methods" (Li Hang).
 *
 * Each boosting round fits the best single-split stump to the current
 * residuals; the final model is the sum of all fitted stumps.
 */
object BoostingTree {
  def main(args: Array[String]): Unit = {
    // Training data: (x, y) pairs from the textbook example.
    val trainData = ArrayBuffer[LabeledPoint]()
    val arr = Array((1,5.56),(2,5.70),(3,5.91),(4,6.40),(5,6.80),(6,7.05),(7,8.90),(8,8.70),(9,9.00),(10,9.05))
    for (ar <- arr) trainData.append(LabeledPoint(ar._2, Vectors.dense(ar._1)))

    // Candidate split points: midpoints between consecutive integer features.
    val cutpoints = ArrayBuffer[Double]()
    for (i <- 0 until arr.length - 1) cutpoints.append(arr(i)._1 + 0.5)

    // Stop boosting once the squared-error loss drops to this threshold.
    // NOTE(review): if no stump can reduce the loss below Li, this loop
    // never terminates — acceptable for the fixed textbook data set.
    val Li = 0.2

    // Accumulates the fitted stumps; the final model is their sum.
    val modelArr = ArrayBuffer[Double => Double]()
    var flag = true
    while (flag) {
      // Fit the best stump to the current residuals.
      val bestModel = getTree(trainData.toArray, cutpoints.toArray)
      modelArr.append(bestModel)
      // Replace each label with its residual for the next boosting round.
      for (i <- 0 until trainData.length) {
        val newLabel = trainData(i).label - bestModel(trainData(i).features(0))
        trainData(i) = LabeledPoint(newLabel, Vectors.dense(trainData(i).features(0)))
      }
      // Squared-error loss over the residual table.
      var sle = 0.0
      for (s <- trainData) sle += math.pow(s.label, 2)
      if (sle <= Li) flag = false
    }

    // The boosted model: sum of all fitted stumps.
    val finalModel = (x: Double) => {
      var result = 0.0
      for (model <- modelArr) result += model(x)
      result
    }

    // Predictions for a few test inputs.
    val csdata = Array(1.4, 2.5, 2.8, 3.5, 4.0, 4.5, 5.6, 6.5, 6.7)
    for (cs <- csdata) println(finalModel(cs))
  }

  /**
   * Fits the best regression stump over the given candidate cutpoints.
   *
   * @param data      training points (single-feature) with current residual labels
   * @param cutpoints candidate split thresholds
   * @return a stump: x => c1 when x <= best cutpoint, else c2
   */
  def getTree(data: Array[LabeledPoint], cutpoints: Array[Double]): Double => Double = {
    // Evaluate every candidate split: (cutpoint, (squaredError, c1, c2)).
    val msArr = ArrayBuffer[(Double, (Double, Double, Double))]()
    for (cutpoint <- cutpoints) msArr.append((cutpoint, calms(cutpoint, data)))
    // Pick the split with the minimal squared error.
    val best = msArr.minBy(_._2._1)
    val cp = best._1
    val c1 = best._2._2
    val c2 = best._2._3
    // BUG FIX: use <= so prediction agrees with the training rule in calms,
    // which assigns points with feature <= cutpoint to region c1. The old
    // `<` misclassified inputs landing exactly on a cutpoint (e.g. 2.5).
    (x: Double) => if (x <= cp) c1 else c2
  }

  /**
   * Squared-error loss of a stump split at `cutpoint`, assuming the
   * squared-error loss f = min sum(yi - c1)^2 + min sum(yi - c2)^2,
   * which is minimized by the region means c1, c2.
   *
   * @return (squared error, mean of labels with feature <= cutpoint,
   *          mean of labels with feature > cutpoint)
   */
  def calms(cutpoint: Double, data: Array[LabeledPoint]): (Double, Double, Double) = {
    // Sums and counts for the two regions split by the cutpoint.
    var s1 = 0.0
    var s2 = 0.0
    var n1 = 0
    var n2 = 0
    for (dt <- data)
      if (dt.features(0) <= cutpoint) { n1 += 1; s1 += dt.label }
      else { n2 += 1; s2 += dt.label }
    // Guard empty regions (cutpoint outside the data range): 0.0 instead of
    // NaN, so a degenerate split cannot corrupt the minBy selection upstream.
    val c1 = if (n1 == 0) 0.0 else s1 / n1
    val c2 = if (n2 == 0) 0.0 else s2 / n2
    // Total squared error of predicting the region mean in each region.
    var ms = 0.0
    for (dt <- data)
      ms += math.pow(dt.label - (if (dt.features(0) <= cutpoint) c1 else c2), 2)
    (ms, c1, c2)
  }
}
/* ---------------------------------------------------------- result ---------------------------------------
Recorded output of the original run (one prediction per test input):
5.63
5.818310185185186
5.818310185185186
6.551643518518518
6.551643518518518
6.819699074074074
6.819699074074074
8.950162037037037
8.950162037037037
This result matches the boosting-tree example in "Statistical Learning Methods" (Li Hang) exactly.
*/