SparkSQL: Implementing UDFs and UDAFs in Java

UDF

From the Spark source: the register overloads support UDFs with up to 22 input parameters. The two-parameter overload, for example:

// registration overload for a UDF taking two parameters
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
  val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
  functionRegistry.registerFunction(
    name,
    (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}
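
For comparison with the Scala source above, here is a minimal Java sketch of registering against this two-parameter overload via the UDF2 interface. The function name addDoubles and its logic are illustrative only, and a SparkSession named spark plus imports of UDF2 and DataTypes are assumed:

// Illustrative two-argument UDF: sums two Double inputs.
spark.udf().register("addDoubles", new UDF2<Double, Double, Double>() {
    @Override
    public Double call(Double a, Double b) throws Exception {
        return a + b;
    }
}, DataTypes.DoubleType);  // third argument: the declared return type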

Registration: spark.udf().register(functionName, functionBody, returnType);

public static void main(String[] args) {
    SparkSession spark = SparkSession
            .builder()
            .appName("SqlDataSource")
            .master("local")
            .getOrCreate();

    // keep two decimal places, rounding half up
    spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
        @Override
        public Double call(Double in) throws Exception {
            BigDecimal b = new BigDecimal(in);
            return b.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue();
        }
    }, DataTypes.DoubleType);
}
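
Once registered, the UDF is callable from SQL like a built-in function. A quick, hypothetical sanity check (the CAST ensures the literal reaches the UDF as a Double rather than a Decimal):

// Expected to print 3.14
spark.sql("SELECT twoDecimal(CAST(3.14159 AS DOUBLE)) AS rounded").show();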

UDAF

Extend UserDefinedAggregateFunction and override its eight methods. The example below computes an average:

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MyUDAF extends UserDefinedAggregateFunction {
    private StructType inputSchema;
    private StructType bufferSchema;

    public MyUDAF() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    // 1. Data type of this aggregate function's input arguments
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    // 2. Data types of the values in the aggregation buffer (order matters)
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    // 3. Data type of the return value
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // 4. Whether this function always returns the same output for the same input; usually true
    @Override
    public boolean deterministic() {
        return true;
    }

    // 5. Initialize the given aggregation buffer: sum = 0 at index 0, count = 0 at index 1
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0D);
        buffer.update(1, 0D);
    }

    // 6. Update the buffer with a new input row
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // skip rows whose value at index 0 is null
        if (!input.isNullAt(0)) {
            double updatedSum = buffer.getDouble(0) + input.getDouble(0);
            double updatedCount = buffer.getDouble(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }
    }

    // 7. Merge two aggregation buffers, storing the merged values back into buffer1
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        double mergedSum = buffer1.getDouble(0) + buffer2.getDouble(0);
        double mergedCount = buffer1.getDouble(1) + buffer2.getDouble(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    // 8. Compute the final result
    @Override
    public Double evaluate(Row buffer) {
        return buffer.getDouble(0) / buffer.getDouble(1);
    }
}

Calling the UDAF:

import java.math.BigDecimal;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

public class RunMyUDAF {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("RunMyUDAF")
                .master("local")
                .getOrCreate();

        // Register the function to access it
        spark.udf().register("myAverage", new MyUDAF());

        Dataset<Row> df = spark.read().json("src/main/resources/employees.json");
        df.createOrReplaceTempView("employees");
        df.show();

//        +-------+------+
//        |   name|salary|
//        +-------+------+
//        |Michael|     0|
//        |   Andy|  4537|
//        | Justin|  3500|
//        |  Berta|     0|
//        |Michael|  3000|
//        |   Andy|  4500|
//        | Justin|  3500|
//        |  Berta|  4000|
//        |   Andy|  4500|
//        +-------+------+

        // keep two decimal places, rounding half up
        spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
            @Override
            public Double call(Double in) throws Exception {
                BigDecimal b = new BigDecimal(in);
                return b.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue();
            }
        }, DataTypes.DoubleType);


        Dataset<Row> result = spark.sql(
                "SELECT name, twoDecimal(myAverage(salary)) AS avg_salary " +
                "FROM employees GROUP BY name");
        result.show();

//        +-------+----------+
//        |   name|avg_salary|
//        +-------+----------+
//        |Michael|    1500.0|
//        |   Andy|   4512.33|
//        |  Berta|    2000.0|
//        | Justin|    3500.0|
//        +-------+----------+

        spark.stop();
    }
}
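
As an aside, the same query can also be expressed through the DataFrame API rather than SQL. A minimal sketch, assuming static imports of callUDF and col from org.apache.spark.sql.functions:

// Equivalent to the SQL above: group by name, average the salaries, round to two decimals.
// Requires: import static org.apache.spark.sql.functions.callUDF;
//           import static org.apache.spark.sql.functions.col;
Dataset<Row> byName = df.groupBy("name")
        .agg(callUDF("twoDecimal", callUDF("myAverage", col("salary"))).alias("avg_salary"));
byName.show();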