Spark custom UDF (strongly typed)

The example below builds a strongly typed user-defined aggregate function by extending Aggregator, wraps it as a typed column, and uses it to compute the average age over a Dataset.

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

object learn04 {

  def main(args: Array[String]): Unit = {
    // Basic configuration
    val conf = new SparkConf().setAppName("learn04").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._
    // Create an RDD and convert it to a Dataset
    val dataRdd = spark.sparkContext.makeRDD(List(1, 2, 3, 4, 5))
    val dataDs = dataRdd.map(age => UserBean(age)).toDS()
    // Wrap the aggregator as a typed column and show the result
    val avgFun = new MyAgeAvgClassFunction
    val avgColumn = avgFun.toColumn.name("avgFun")
    dataDs.select(avgColumn).show()

    spark.stop()
  }

}
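For the five input rows (ages 1 through 5) the aggregator should return an average of 3.0. Assuming the default show() formatting, the console output would look roughly like:

+------+
|avgFun|
+------+
|   3.0|
+------+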

case class UserBean(age: BigInt)

case class AvgBuffer(sum: BigInt, count: Int)

/**
  * Computes the average age.
  * 1) Extend Aggregator[IN, BUF, OUT] (input, buffer, output types)
  * 2) Implement its methods
  */
class MyAgeAvgClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {

  // Initial (zero) value of the aggregation buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  // Fold one input row into the buffer (within a partition)
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    val c = b.sum + a.age
    val d = b.count + 1
    AvgBuffer(c, d)
  }

  // Merge buffers coming from different partitions
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    val a = b1.sum + b2.sum
    val b = b1.count + b2.count
    AvgBuffer(a, b)
  }

  // Produce the final output from the buffer
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Use Encoders.product for user-defined case classes; use the built-in encoders for standard Scala types
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
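The same typed aggregator can also be applied per group. Below is a minimal sketch, assuming the dataDs Dataset and the spark.implicits._ import from the main method above are in scope; the even/odd grouping key is purely hypothetical, chosen only to illustrate groupByKey.

// Sketch only: per-group average with the same typed aggregator.
// Assumes `dataDs: Dataset[UserBean]` and `import spark.implicits._` are in scope.
val avgByParity = dataDs
  .groupByKey(user => (user.age % 2).toInt)   // hypothetical key: 0 = even age, 1 = odd age
  .agg(new MyAgeAvgClassFunction().toColumn.name("avgAge"))
avgByParity.show()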