需求:一個score表,就一個字段score,求其avg
1>建表語句以及其數據
create table test_score(
score bigint
);
10
30
25
25
9
2>具體實現代碼
package com.hnxy.function;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
/**
 * Custom Hive UDAF computing the average of a single BIGINT column.
 *
 * <p>Aggregation state is a (sum, count) pair. In the partial phases the pair is
 * shipped between tasks as a two-field struct {sum: bigint, count: bigint}; in the
 * final phase the average is formatted as a string like "19.8".
 *
 * <p>Usage:
 * <pre>
 *   create temporary function myfunc as 'com.hnxy.function.MySumAvgUDAF';
 *   select myfunc(score) from test_score;
 * </pre>
 */
public class MySumAvgUDAF extends AbstractGenericUDAFResolver {

    /**
     * Validates the call site (exactly one argument, primitive, LONG/bigint)
     * and hands back the evaluator that does the actual work.
     *
     * @param info type info of the call's arguments
     * @return the evaluator instance
     * @throws SemanticException if the argument count or type is wrong
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        if (info.length != 1) {
            throw new UDFArgumentException("參數必須是一個!");
        }
        if (!info[0].getCategory().equals(ObjectInspector.Category.PRIMITIVE)) {
            throw new UDFArgumentException("參數必須是PRIMITIVE類型!");
        }
        PrimitiveTypeInfo p = (PrimitiveTypeInfo) info[0];
        if (!p.getPrimitiveCategory().equals(PrimitiveObjectInspector.PrimitiveCategory.LONG)) {
            throw new UDFArgumentException("參數必須Long類型!");
        }
        return new MyUDAFEvaluator();
    }

    private static class MyUDAFEvaluator extends GenericUDAFEvaluator {

        // Reusable output holders — Hive allows evaluators to recycle these
        // between calls instead of allocating per row/group.
        private final Object[] outKey = {new LongWritable(), new LongWritable()};
        private final Text outValue = new Text();

        /** Aggregation buffer: running sum and row count for one group. */
        private static class MyAgg implements AggregationBuffer {
            // primitive longs — avoids autoboxing on every processed row
            private long sum = 0L;
            private long count = 0L;

            public long getSum() {
                return sum;
            }

            public void setSum(long sum) {
                this.sum = sum;
            }

            public long getCount() {
                return count;
            }

            public void setCount(long count) {
                this.count = count;
            }
        }

        /**
         * Declares the output shape for each aggregation mode: the partial
         * phases emit a {sum, count} struct; the final/complete phases emit
         * the formatted average as a string.
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m.equals(Mode.PARTIAL1) || m.equals(Mode.PARTIAL2)) {
                List<String> structFieldNames = new ArrayList<>();
                List<ObjectInspector> structFieldTypes = new ArrayList<>();
                structFieldNames.add("sum");
                structFieldNames.add("count");
                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                return ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames, structFieldTypes);
            }
            return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new MyAgg();
        }

        /** Resets the buffer so Hive can reuse it for the next group. */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            MyAgg ag = (MyAgg) agg;
            ag.setSum(0L);
            ag.setCount(0L);
        }

        /**
         * Consumes one input row. NULL scores are skipped (neither summed nor
         * counted), matching the semantics of the built-in avg().
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters[0] == null) {
                return; // skip NULLs instead of throwing NPE on toString()
            }
            MyAgg ag = (MyAgg) agg;
            ag.setSum(ag.getSum() + Long.parseLong(parameters[0].toString()));
            ag.setCount(ag.getCount() + 1);
            printStr("iterate" + "----->" + ag.getSum() + " : " + ag.getCount());
        }

        /** Emits this task's partial {sum, count} struct. */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            MyAgg ag = (MyAgg) agg;
            printStr("terminatePartial" + "----->" + ag.getSum() + " : " + ag.getCount());
            ((LongWritable) outKey[0]).set(ag.getSum());
            ((LongWritable) outKey[1]).set(ag.getCount());
            return outKey;
        }

        /**
         * Folds one partial {sum, count} into the buffer.
         *
         * <p>BUG FIX: the previous version OVERWROTE the buffer with each
         * incoming partial instead of accumulating, which silently discards
         * data whenever a group's rows span more than one mapper/partial.
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            MyAgg ag = (MyAgg) agg;
            printStr("merge" + "----->" + ag.getSum() + " : " + ag.getCount());
            if (partial instanceof LazyBinaryStruct) {
                LazyBinaryStruct ls = (LazyBinaryStruct) partial;
                LongWritable sum = (LongWritable) ls.getField(0);
                LongWritable count = (LongWritable) ls.getField(1);
                // accumulate — do NOT overwrite
                ag.setSum(ag.getSum() + sum.get());
                ag.setCount(ag.getCount() + count.get());
            }
        }

        /**
         * Produces the final formatted average.
         *
         * <p>BUG FIX: the previous version divided two longs BEFORE converting
         * to double, truncating the result (e.g. 99/5 printed "19.0" instead
         * of "19.8"). Cast first, then divide. Also returns NULL for an empty
         * group instead of dividing by zero.
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            MyAgg ag = (MyAgg) agg;
            if (ag.getCount() == 0) {
                return null; // no rows aggregated -> NULL, like built-in avg()
            }
            double avg = (double) ag.getSum() / ag.getCount();
            DecimalFormat decimalFormat = new DecimalFormat("###,###.0");
            outValue.set(decimalFormat.format(avg));
            return outValue;
        }

        /** Debug trace helper; output lands in the task's stdout log. */
        public void printStr(String str) {
            System.out.println("-----------------" + str + "-----------------");
        }
    }
}
此處需要在本地配置 Hadoop 和 Hive 環境(本例使用 IDEA 2017.3 開發)。
hive運行:
create temporary function myfunc as 'com.hnxy.function.MySumAvgUDAF';
select myfunc(score) score from test_score;
結果:
score
19.8
(註:(10+30+25+25+9)/5 = 19.8。若輸出為 19.0,說明 terminate 中先做了 long 整數除法再轉 double,小數部分被截斷;應改為 (double) sum / count 再除。)