Spark custom functions (UDF, UDAF)

Spark version: 2.3

Test data used in this article (JSON):

{"name":"lillcol", "age":24,"ip":"192.168.0.8"}
{"name":"adson", "age":100,"ip":"192.168.255.1"}
{"name":"wuli", "age":39,"ip":"192.143.255.1"}
{"name":"gu", "age":20,"ip":"192.168.255.1"}
{"name":"ason", "age":15,"ip":"243.168.255.9"}
{"name":"tianba", "age":1,"ip":"108.168.255.1"}
{"name":"clearlove", "age":25,"ip":"222.168.255.110"}
{"name":"clearlove", "age":30,"ip":"222.168.255.110"}

User-defined UDF

There are two ways to define a UDF:

  1. Register a function with SQLContext.udf.register()
  2. Create a UserDefinedFunction with udf()

The two approaches have different usage scopes: a function registered with register() can only be called from SQL, while a UserDefinedFunction created with udf() is used in the DSL (DataFrame/Dataset API). Both are shown in the example below.

package com.test.spark

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * @author Administrator
  *         2019/7/22-14:04
  *
  */
object TestUdf {

  val spark = SparkSession
    .builder()
    .appName("TestCreateDataset")
    .config("spark.some.config.option", "some-value")
    .master("local")
    .enableHiveSupport()
    .getOrCreate()
  val sQLContext = spark.sqlContext

  import spark.implicits._


  def main(args: Array[String]): Unit = {
    testudf()
  }

  def testudf() = {
    val iptoLong: UserDefinedFunction = getIpToLong()
    val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
    ds.createOrReplaceTempView("table1")
    sQLContext.udf.register("addName", sqlUdf(_: String)) // addName can only be used in SQL, not in the DSL
    //1.SQL
    sQLContext.sql("select *,addName(name) as nameAddName  from table1")
      .show()
    //2.DSL
    val addName: UserDefinedFunction = udf((str: String) => ("ip: " + str))
    ds.select($"*", addName($"ip").as("ipAddName"))
      .show()

    // If the UDF is relatively complex, it can be factored out into its own method, e.g. iptoLong
    ds.select($"*", iptoLong($"ip").as("iptoLong"))
      .show()
  }

  def sqlUdf(name: String): String = {
    "name:" + name
  }

  /**
    * User-defined UDF: converts an IPv4 address string to a Long.
    *
    * @return
    */
  def getIpToLong(): UserDefinedFunction = {
    val ipToLong: UserDefinedFunction = udf((ip: String) => {
      // Strip spaces and quotes, then split the IPv4 address into its four octets
      val arr: Array[String] = ip.replace(" ", "").replace("\"", "").split("\\.")
      var result: Long = 0
      var ipl: Long = 0
      if (arr.length == 4) {
        for (i <- 0 to 3) {
          ipl = arr(i).toLong
          // Octet i is shifted into bit positions (3 - i) * 8 .. (3 - i) * 8 + 7
          result |= ipl << ((3 - i) << 3)
        }
      } else {
        // Not a valid dotted-quad address
        result = -1
      }
      result
    })
    ipToLong
  }


}
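
As a quick check of the bit arithmetic in getIpToLong: each octet is shifted into its own 8-bit slot, so 192.168.0.8 becomes 192 * 2^24 + 168 * 2^16 + 0 * 2^8 + 8 = 3221225472 + 11010048 + 0 + 8 = 3232235528, matching the iptoLong column in the last table below.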

Output:
+---+---------------+---------+--------------+
|age|             ip|     name|   nameAddName|
+---+---------------+---------+--------------+
| 24|    192.168.0.8|  lillcol|  name:lillcol|
|100|  192.168.255.1|    adson|    name:adson|
| 39|  192.143.255.1|     wuli|     name:wuli|
| 20|  192.168.255.1|       gu|       name:gu|
| 15|  243.168.255.9|     ason|     name:ason|
|  1|  108.168.255.1|   tianba|   name:tianba|
| 25|222.168.255.110|clearlove|name:clearlove|
| 30|222.168.255.110|clearlove|name:clearlove|
+---+---------------+---------+--------------+

+---+---------------+---------+-------------------+
|age|             ip|     name|          ipAddName|
+---+---------------+---------+-------------------+
| 24|    192.168.0.8|  lillcol|    ip: 192.168.0.8|
|100|  192.168.255.1|    adson|  ip: 192.168.255.1|
| 39|  192.143.255.1|     wuli|  ip: 192.143.255.1|
| 20|  192.168.255.1|       gu|  ip: 192.168.255.1|
| 15|  243.168.255.9|     ason|  ip: 243.168.255.9|
|  1|  108.168.255.1|   tianba|  ip: 108.168.255.1|
| 25|222.168.255.110|clearlove|ip: 222.168.255.110|
| 30|222.168.255.110|clearlove|ip: 222.168.255.110|
+---+---------------+---------+-------------------+

+---+---------------+---------+----------+
|age|             ip|     name|  iptoLong|
+---+---------------+---------+----------+
| 24|    192.168.0.8|  lillcol|3232235528|
|100|  192.168.255.1|    adson|3232300801|
| 39|  192.143.255.1|     wuli|3230662401|
| 20|  192.168.255.1|       gu|3232300801|
| 15|  243.168.255.9|     ason|4087938825|
|  1|  108.168.255.1|   tianba|1823014657|
| 25|222.168.255.110|clearlove|3735617390|
| 30|222.168.255.110|clearlove|3735617390|
+---+---------------+---------+----------+
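
Note that the split in usage scope is not absolute: a UserDefinedFunction built with udf() can additionally be registered for SQL. A minimal sketch following on from the example above (the SQL name ipToLong is chosen here purely for illustration):

// Expose the DSL-style UserDefinedFunction to SQL as well
spark.udf.register("ipToLong", iptoLong)
spark.sql("select name, ipToLong(ip) as ipAsLong from table1").show()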

User-defined UDAF (aggregate functions)

Weakly typed user-defined aggregate function

Implemented by extending UserDefinedAggregateFunction:

package com.test.spark

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * @author lillcol
  *         2019/7/22-15:09
  *         Weakly typed user-defined aggregate function
  */
object TestUDAF extends UserDefinedAggregateFunction {
  // Data types of the aggregate function's input arguments
  // :: prepends an element to the head of a list and returns a new list; Nil is the empty List, defined as List[Nothing]
  override def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)

  // Equivalent to:
  //  override def inputSchema: StructType = new StructType().add("age", IntegerType)

  // Data types of the values in the aggregation buffer
  override def bufferSchema: StructType = {
    StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
  }

  // Data type of the value returned by this UserDefinedAggregateFunction
  override def dataType: DataType = DoubleType

  // Whether this function is deterministic, i.e. whether it always returns the same output for the same input
  override def deterministic: Boolean = true

  //  Initializes the given aggregation buffer, i.e. the buffer's zero value.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // sum: total age
    buffer(0) = 0
    // count: number of people
    buffer(1) = 0
  }

  //  Updates the given aggregation buffer with new input data.
  // Called once per input row (within the same partition).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + input.getInt(0) // accumulate age
    buffer(1) = buffer.getInt(1) + 1 // accumulate count
  }

  //  Merges two aggregation buffers and stores the updated values back into buffer1.
  // Called when merging partially aggregated data (across partitions).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0) // accumulate age
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1) // accumulate count
  }

  // Computes the final result: average age = total age / number of people
  override def evaluate(buffer: Row): Any = {
    buffer.getInt(0).toDouble / buffer.getInt(1)
  }

  val spark = SparkSession
    .builder()
    .appName("Spark SQL basic example")
    // .config("spark.some.config.option", "some-value")
    .master("local[*]") // 本地測試
    .getOrCreate()

  import spark.implicits._

  def main(args: Array[String]): Unit = {
    spark.udf.register("myAvg", TestUDAF)
    val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
    ds.createOrReplaceTempView("table1")
    //SQL
    spark.sql("select myAvg(age) as avgAge from table1")
      .show()

    //DSL
    ds.select(TestUDAF($"age").as("avgAge"))
      .show()
  }
}

Output:
+------+
|avgAge|
+------+
| 31.75|
+------+

+------+
|avgAge|
+------+
| 31.75|
+------+
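
Like any aggregate function, the UDAF can also be applied per group instead of over the whole table. A brief sketch reusing ds, table1, and the myAvg registration from the example above:

//DSL: average age per name
ds.groupBy($"name")
  .agg(TestUDAF($"age").as("avgAge"))
  .show()

//SQL equivalent
spark.sql("select name, myAvg(age) as avgAge from table1 group by name").show()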

Strongly typed user-defined aggregate function

Implemented by extending Aggregator (the one under org.apache.spark.sql.expressions; make sure you import the right package):

package com.test.spark

import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions._

/**
  * @author Administrator
  *         2019/7/22-16:07
  *
  */
// Since this is the strongly typed API, case classes define the input and buffer types
case class Person(name: String, age: Double, ip: String)

case class Average(var sum: Double, var count: Double)

object MyAverage extends Aggregator[Person, Average, Double] {
  //  The zero value of this aggregation; it should satisfy b + zero = b for any b.
  //  The buffer holds the running total of ages and the running count, both starting at 0.
  override def zero: Average = {
    Average(0, 0)
  }

  //  Combines two values to produce a new value. For performance, the function may modify b and return it
  //  instead of constructing a new object. Merges input rows within the same executor (same partition).
  override def reduce(b: Average, a: Person): Average = {
    b.sum += a.age
    b.count += 1
    b
  }

  // Merges two intermediate values.
  // Combines partial aggregates from different executors (different partitions).
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Computes the final result
  override def finish(reduction: Average): Double = {
    reduction.sum / reduction.count
  }

  //  Specifies the Encoder for the intermediate value type.
  override def bufferEncoder: Encoder[Average] = Encoders.product

  //  Specifies the Encoder for the final output value type.
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

  val spark = SparkSession
    .builder()
    .appName("Spark SQL basic example")
    // .config("spark.some.config.option", "some-value")
    .master("local[*]") // 本地測試
    .getOrCreate()

  import spark.implicits._

  def main(args: Array[String]): Unit = {
    val ds: Dataset[Person] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson").as[Person]
    ds.show()

    val avgAge = MyAverage.toColumn /*.name("avgAge")*/ // the alias avgAge would be given to the column here
    ds.select(avgAge) // calling avgAge.as("columnName") here throws org.apache.spark.sql.AnalysisException; the alias can only be set with .name() above (at least in my tests)
      .show()
  }
}

Output:
+---+---------------+---------+
|age|             ip|     name|
+---+---------------+---------+
| 24|    192.168.0.8|  lillcol|
|100|  192.168.255.1|    adson|
| 39|  192.143.255.1|     wuli|
| 20|  192.168.255.1|       gu|
| 15|  243.168.255.9|     ason|
|  1|  108.168.255.1|   tianba|
| 25|222.168.255.110|clearlove|
| 30|222.168.255.110|clearlove|
+---+---------------+---------+

+------+
|avgAge|
+------+
| 31.75|
+------+
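
The typed Aggregator can likewise be applied per group through the typed API. A minimal sketch, assuming the same Dataset[Person] ds as above (avgAgeByName is just an illustrative name):

// Average age per name; agg on a KeyValueGroupedDataset accepts the TypedColumn directly
val avgAgeByName: Dataset[(String, Double)] = ds.groupByKey(_.name).agg(MyAverage.toColumn)
avgAgeByName.show()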

This is an original article; please credit the source when reposting!
