UDF
Source: the register overloads accept UDF1 through UDF22, so a UDF can take at most 22 input parameters.
// overload that registers a UDF with two input parameters
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
  val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
  functionRegistry.registerFunction(
    name,
    (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}
Registration: spark.udf().register(functionName, functionBody, returnType);
public static void main(String[] args) {
    SparkSession spark = SparkSession
        .builder()
        .appName("SqlDataSource")
        .master("local")
        .getOrCreate();

    // Round to two decimal places (half-up rounding)
    spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
        @Override
        public Double call(Double in) throws Exception {
            BigDecimal b = new BigDecimal(in);
            double res = b.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue();
            return res;
        }
    }, DataTypes.DoubleType);
}
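Since the register overload in the source excerpt above takes a UDF2 (org.apache.spark.sql.api.java.UDF2), a two-argument UDF is registered the same way. A minimal sketch; the function name "addBonus" and the salary/bonus columns are hypothetical, for illustration only:
// Hypothetical two-argument UDF: adds a bonus to a salary
spark.udf().register("addBonus", new UDF2<Double, Double, Double>() {
    @Override
    public Double call(Double salary, Double bonus) throws Exception {
        return salary + bonus;
    }
}, DataTypes.DoubleType);

// Usage in SQL, assuming a table with salary and bonus columns:
// spark.sql("SELECT addBonus(salary, bonus) FROM employees").show();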
UDAF
Extend UserDefinedAggregateFunction and override its eight methods.
The example below computes an average:
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MyUDAF extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyUDAF() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }
    // 1. Data types of this aggregate function's input arguments
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    // 2. Data types of the values in the aggregation buffer (field order matters)
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    // 3. Data type of the return value
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // 4. Whether this function always returns the same output for the same input; usually true
    @Override
    public boolean deterministic() {
        return true;
    }

    // 5. Initialize the aggregation buffer: sum (index 0) = 0, count (index 1) = 0
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0D);
        buffer.update(1, 0D);
    }

    // 6. Update the buffer with a new input row
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // Skip rows whose value at index 0 is null
        if (!input.isNullAt(0)) {
            double updateSum = buffer.getDouble(0) + input.getDouble(0);
            double updateCount = buffer.getDouble(1) + 1;
            buffer.update(0, updateSum);
            buffer.update(1, updateCount);
        }
    }

    // 7. Merge two aggregation buffers and store the merged values back into buffer1
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        double mergeSum = buffer1.getDouble(0) + buffer2.getDouble(0);
        double mergeCount = buffer1.getDouble(1) + buffer2.getDouble(1);
        buffer1.update(0, mergeSum);
        buffer1.update(1, mergeCount);
    }

    // 8. Compute the final result
    @Override
    public Double evaluate(Row buffer) {
        return buffer.getDouble(0) / buffer.getDouble(1);
    }
}
Calling the UDAF:
import java.math.BigDecimal;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

public class RunMyUDAF {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
            .builder()
            .appName("RunMyUDAF")
            .master("local")
            .getOrCreate();

        // Register the UDAF so it can be called from SQL
        spark.udf().register("myAverage", new MyUDAF());

        Dataset<Row> df = spark.read().json("src/main/resources/employees.json");
        df.createOrReplaceTempView("employees");
        df.show();
        // +-------+------+
        // |   name|salary|
        // +-------+------+
        // |Michael|     0|
        // |   Andy|  4537|
        // | Justin|  3500|
        // |  Berta|     0|
        // |Michael|  3000|
        // |   Andy|  4500|
        // | Justin|  3500|
        // |  Berta|  4000|
        // |   Andy|  4500|
        // +-------+------+
        // Round the averages to two decimal places (half-up rounding)
        spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
            @Override
            public Double call(Double in) throws Exception {
                BigDecimal b = new BigDecimal(in);
                double res = b.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue();
                return res;
            }
        }, DataTypes.DoubleType);
        Dataset<Row> result = spark
            .sql("SELECT name, twoDecimal(myAverage(salary)) AS avg_salary FROM employees GROUP BY name");
        result.show();
        // +-------+----------+
        // |   name|avg_salary|
        // +-------+----------+
        // |Michael|    1500.0|
        // |   Andy|   4512.33|
        // |  Berta|    2000.0|
        // | Justin|    3500.0|
        // +-------+----------+
        spark.stop();
    }
}
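The registered UDAF can also be invoked through the DataFrame API instead of SQL. A minimal sketch using functions.callUDF, assuming "myAverage" has been registered and df loaded as in RunMyUDAF above:
import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;

// Equivalent aggregation through the DataFrame API
Dataset<Row> result2 = df.groupBy("name")
    .agg(callUDF("myAverage", col("salary")).as("avg_salary"));
result2.show();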