Hive用戶自定義函數UDAF開發

釋義

UDAF是User Defined Aggregation Function的簡稱。UDAF用來進行聚合運算,其輸入是多行數據,輸出一個計算結果。

如何開發

UDAF有兩種實現方式:繼承UDAF;或繼承AbstractGenericUDAFResolver。前一種方式是簡單的方式,但其使用了java的反射機制,因此性能上比後一種方式要低效,因此生產上不建議使用第一種方式。

計算的邏輯設計

繼承AbstractGenericUDAFResolver抽象類,需要實現一個getEvaluator方法,該方法返回一個實例,該實例繼承GnericEvaluator抽象類,GnericEvaluator的實現類纔是進行聚合計算的具體實現類。

GnericEvaluator有以下幾個方法需要開發自己實現:getNewAggregationBuffer、iterate、merge、reset、terminatePartial、terminate。另外需要重寫init方法,該方法在抽象類中已經實現,但其返回值是null,在開發時如果不重新該方法,可能會在調用時報控指針異常。

一般的UDAF函數在hive計算過程中涉及到三個階段,對應方式:PARTIAL1、PARTIAL2、FINAL。

PARTIAL1階段是map階段計算,此階段會調用函數的init、iterate、terminatePartial三個方法。

PARTIAL2階段是map後的combine階段,是部分結果聚合,此階段會調用函數的init、merge、terminatePartial三個方法。

FINAL階段是reduce階段,輸出給hive最終結果,此階段會調用函數的init、merge、terminate三個方法。

一些特殊的UDAF只有Map階段,對應方式:COMPLETE。

COMPLETE只有Map階段,其調用方法init、iterate、terminate三個方法。

方法釋義

getNewAggregationBuffer方法是返回一個AggregationBuffer實例,該實例實現了AggregationBuffer接口,該接口是個空的接口,其方法是根據實際需要自己定義。AggregationBuffer實例是用來緩存中間及最後聚合結果的。

reset方法是重置AggregationBuffer實例。

iterate方法是逐行處理輸入的數據的。

merge是進行計算結果合併的,包括combine階段及reduce階段。

terminatePartial是用來對Map及combine階段的結果進行持久化的,其返回的值類型必須是java的原始數據類型及其封裝類、hadoop的writable實現類、List或Map。

terminate是向hive輸出最終結果,輸出的類型同樣只能是java原始數據類型及其封裝類、Hadoop writable類型、List或Map。

init方法是定義各個階段的輸入類型及輸出類型,其輸入及輸出類型必須是ObjectInspector的實現類。

示例

下面的示例是進行商品的銷售額統計計算爲例。

package org.hive.demo;

import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

/**
 * @author sunqiang
 * 該示例是演示UDAF聚合函數的使用,示例是是計算商品的銷售額。
 */
public class UDAF1 extends AbstractGenericUDAFResolver{
	/**
	 * 該方法是用來進行輸入參數的校驗工作,及指定自定義函數的具體實現類。
	 */
	@Override
	public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
		TypeInfo[] params = info.getParameters();
		//校驗參數個數,必須傳遞兩個參數進來
		if(params == null || params.length != 2) {
			throw new UDFArgumentException("Two params must be given.");
		}
		
		//校驗第一個參數類型,必須是int,表示商品銷售數量
		if(params[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentTypeException(0, params[0].getTypeName() + " is not a primitive type.");
		}
		if(((PrimitiveTypeInfo)params[0]).getPrimitiveCategory() != PrimitiveCategory.INT) {
			throw new UDFArgumentTypeException(0,"Parameter 1 must be int, but " + params[0].getTypeName() +" is found.");
		}
		//校驗第二個參數類型,必須是double,表示商品單價
		if(params[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentTypeException(1,params[1].getTypeName() + " is not a primitive type.");
		}
		if(((PrimitiveTypeInfo)params[1]).getPrimitiveCategory() != PrimitiveCategory.DOUBLE) {
			throw new UDFArgumentTypeException(1,"Parameter 1 must be double, but " + params[1].getTypeName() +" is found.");
		}
		
		return new UDAFEvaluator();
	}
	
	/**
	 * 具體的計算方法的實現類
	 * @author sunqiang
	 *
	 */
	public static class UDAFEvaluator extends GenericUDAFEvaluator {
		
		private PrimitiveObjectInspector inputOI1;
		private PrimitiveObjectInspector inputOI2;
		
		/**
		 * 該類是用來緩存中間聚合結果的。
		 * @author sunqiang
		 *
		 */
		class SumDoubleAgg implements AggregationBuffer {
			private double sum = 0d;
			
			public void add(double value) {
				this.sum += value;
			}
			
			public double getSum() {
				return sum;
			}
		}
		
		/**
		 * 該方法即是返回一箇中間聚合結果的緩存實現類。
		 */
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			SumDoubleAgg sa = new SumDoubleAgg();
			return sa;
		}
		
		/**
		 * 該方法在多個階段均會調用,
		 * 一定要實現該方法,否則默認的返回時null,hive調用時就會拋空指針異常,
		 * 該方法的返回值是制定各個階段的輸出類型。
		 */
		@Override  
		public ObjectInspector init(Mode m, ObjectInspector[] params) throws HiveException {
			super.init(m, params);
			if(m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputOI1 = (PrimitiveObjectInspector)params[0];
				inputOI2 = (PrimitiveObjectInspector)params[1];
			}else if(m == Mode.PARTIAL2 || m == Mode.FINAL){
				inputOI1 = (PrimitiveObjectInspector)params[0];
			}
		    return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
		}
		
		/**
		 * 該方法是迭代處理每一行輸入數據。
		 */
		@Override
		public void iterate(AggregationBuffer agg, Object[] params) throws HiveException {
			int p1 = PrimitiveObjectInspectorUtils.getInt(params[0], inputOI1);
			double p2 = PrimitiveObjectInspectorUtils.getDouble(params[1], inputOI2);
			SumDoubleAgg sa = (SumDoubleAgg) agg;
			sa.add(p1*p2);
		}
		
		/**
		 * 該方法是將部分聚合結果進行合併。它是將當前的聚合結果與terminatePartial階段聚合的結果進行合併。
		 */
		@Override
		public void merge(AggregationBuffer agg, Object partial) throws HiveException {
			SumDoubleAgg sa = (SumDoubleAgg)agg;
			double temp = PrimitiveObjectInspectorUtils.getDouble(partial, inputOI1);
			sa.add(temp);
		}
		
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			agg = new SumDoubleAgg();
		}
		
		/**
		 * 這個階段即是最終輸出階段,將結果返回給hive。
		 */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			SumDoubleAgg sa = (SumDoubleAgg)agg;
			return sa.getSum();
		}
		
		/**
		 * 該方法是用來將當前聚合結果進行持久化,返回的類型必須是java原始數據類型、hadoop writable類型、List、Map幾種
		 */
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			SumDoubleAgg sa = (SumDoubleAgg)agg;
			return sa.getSum();
		}
		
	}
}

調用

編譯打包後,beeline連接到hive server,使用add jar添加jar包到hive運行環境;

然後創建hive函數create function myudaf as “org.hive.demo.UDAF1”;

最後sql調用創建的函數即可。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章