什麼是UDAF?
UDAF(User Defined Aggregate Function),即用戶定義的聚合函數,聚合函數和普通函數的區別是什麼呢,普通函數是接受一行輸入產生一個輸出,聚合函數是接受一組(一般是多行)輸入然後產生一個輸出,即將一組的值想辦法聚合一下。類似於sum操作,spark的udf使用看這裏
直接看下面的demo,計算1-10的平均值,代碼也比較簡單
package spark
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import java.lang
/**
* spark的UDAF使用
*/
object UDAF {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("UDAFDemo")
.master("local[1]")
.getOrCreate()
val ds: Dataset[lang.Long] = spark.range(1,10)
ds.createTempView("test")
spark.udf.register("jason",new MyUdaf)
spark.sql("select jason(id) as jason from test").show()
}
}
class MyUdaf extends UserDefinedAggregateFunction {
// 定義聚合函數的輸入結構類型
override def inputSchema: StructType = StructType(Array(StructField("age",IntegerType)))
// 聚合緩衝區中值的數據類型
override def bufferSchema: StructType = StructType(Array(StructField("count",IntegerType),StructField("ages",IntegerType)))
// 聚合函數返回的數據類型
override def dataType: DataType = IntegerType
// 聚合函數是否是冪等的,即相同輸入是否總是能得到相同輸出
override def deterministic: Boolean = true
// 初始化緩衝區
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0; buffer(1) = 0
}
// 寫入新數據後更新緩衝區的值
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (input.isNullAt(0)) return
buffer(0) = buffer.getInt(0) + 1
buffer(1) = buffer.getInt(1) + input.getInt(0)
}
// 合併聚合函數緩衝區
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getInt(0)+buffer2.getInt(0)
buffer1(1) = buffer1.getInt(1)+buffer2.getInt(1)
}
// 計算並返回結果
override def evaluate(buffer: Row): Any = buffer.getInt(1)/buffer.getInt(0)
}
熟悉Flink的朋友會發現,這個用法跟Flink的agg函數特別的像,用法幾乎一模一樣.有興趣的可以去研究一下.
運行打印結果:
+-----+
|jason|
+-----+
| 5|
+-----+
如果有寫的不對的地方,歡迎大家指正,如果有什麼疑問,可以加QQ羣:340297350,更多的Flink和spark的乾貨可以加入下面的星球