SparkSQL中用戶自定義聚合函數——UDAF

  1. UDAF:用戶自定義聚合函數。
  • 實現UDAF函數如果要自定義類要繼承UserDefinedAggregateFunction類

Java代碼:

SparkConf conf = new SparkConf();
conf.setMaster("local").setAppName("udaf");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new SQLContext(sc);
JavaRDD<String> parallelize = sc.parallelize(Arrays.asList("zhansan","lisi","wangwu","zhangsan","zhangsan","lisi"));
JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() {

	/**
	 * 
	 */
	private static final long serialVersionUID = 1L;

	@Override
	public Row call(String s) throws Exception {
              return RowFactory.create(s);
	}
});

List<StructField> fields = new ArrayList<StructField>();
fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("user");
/**
 * 註冊一個UDAF函數,實現統計相同值得個數
 * 注意:這裏可以自定義一個類繼承UserDefinedAggregateFunction類也是可以的
 */
sqlContext.udf().register("StringCount", new UserDefinedAggregateFunction() {
	
   /**
    * 
    */
   private static final long serialVersionUID = 1L;
   /**
    * 更新 可以認爲一個一個地將組內的字段值傳遞進來 實現拼接的邏輯
    * buffer.getInt(0)獲取的是上一次聚合後的值
    * 相當於map端的combiner,combiner就是對每一個map task的處理結果進行一次小聚合 
    * 大聚和發生在reduce端.
    * 這裏即是:在進行聚合的時候,每當有新的值進來,對分組後的聚合如何進行計算
    */
   @Override
   public void update(MutableAggregationBuffer buffer, Row arg1) {
         buffer.update(0, buffer.getInt(0)+1);

   }
   /**
    * 合併 update操作,可能是針對一個分組內的部分數據,在某個節點上發生的 但是可能一個分組內的數據,會分佈在多個節點上處理
    * 此時就要用merge操作,將各個節點上分佈式拼接好的串,合併起來
    * buffer1.getInt(0) : 大聚和的時候 上一次聚合後的值       
    * buffer2.getInt(0) : 這次計算傳入進來的update的結果
    * 這裏即是:最後在分佈式節點完成後需要進行全局級別的Merge操作
    */
   @Override
   public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
     buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
   }
   /**
    * 指定輸入字段的字段及類型
    */
   @Override
   public StructType inputSchema() {
     return DataTypes.createStructType(
      Arrays.asList(DataTypes.createStructField("name", 
          DataTypes.StringType, true)));
   }
   /**
    * 初始化一個內部的自己定義的值,在Aggregate之前每組數據的初始化結果
    */
   @Override
   public void initialize(MutableAggregationBuffer buffer) {
         buffer.update(0, 0);
   }
   /**
    * 最後返回一個和DataType的類型要一致的類型,返回UDAF最後的計算結果
    */
   @Override
   public Object evaluate(Row row) {
      return row.getInt(0);
   }
   
   @Override
   public boolean deterministic() {
     //設置爲true
     return true;
   }
   /**
    * 指定UDAF函數計算後返回的結果類型
    */
   @Override
   public DataType dataType() {
      return DataTypes.IntegerType;
   }
   /**
    * 在進行聚合操作的時候所要處理的數據的結果的類型
    */
   @Override
   public StructType bufferSchema() {
       return 
       DataTypes.createStructType(
   Arrays.asList(DataTypes.createStructField("bf", DataTypes.IntegerType, 
            true)));
   }
   
});

sqlContext.sql("select name ,StringCount(name) from user group by name").show();

sc.stop();

Scala代碼:

class MyUDAF extends UserDefinedAggregateFunction  {
  // 聚合操作時,所處理的數據的類型
  def bufferSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("aaa", IntegerType, true)))
  }
  // 最終函數返回值的類型
  def dataType: DataType = {
    DataTypes.IntegerType
  }

  def deterministic: Boolean = {
    true
  }
  // 最後返回一個最終的聚合值     要和dataType的類型一一對應
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
  // 爲每個分組的數據執行初始化值
  def initialize(buffer: MutableAggregationBuffer): Unit = {
     buffer(0) = 0
  }
  //輸入數據的類型
  def inputSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("input", StringType, true)))
  }
  // 最後merger的時候,在各個節點上的聚合值,要進行merge,也就是合併
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0)+buffer2.getAs[Int](0) 
  }
  // 每個組,有新的值進來的時候,進行分組對應的聚合值的計算
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0)+1
  }
}

object UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setMaster("local").setAppName("udaf")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.makeRDD(Array("zhangsan","lisi","wangwu","zhangsan","lisi"))
    val rowRDD = rdd.map { x => {RowFactory.create(x)} }
    
    val schema = DataTypes.createStructType(Array(DataTypes.createStructField("name", StringType, true)))
    val df = sqlContext.createDataFrame(rowRDD, schema)
    df.show()
    df.registerTempTable("user")
    /**
     * 註冊一個udaf函數
     */
    sqlContext.udf.register("StringCount", new MyUDAF())
    sqlContext.sql("select name ,StringCount(name) from user group by name").show()
    sc.stop()
  }
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章