A Simple UDAF Implementation in Hive

Requirement: a score table with a single column score; compute its average (avg).

1> Table DDL and sample data

create table test_score(
score bigint
);


Contents of the source data file (one score per line):

10
30
25
25
9
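
These five scores can be loaded into the table from a local text file. A minimal sketch, assuming the file was saved as /tmp/score.txt (hypothetical path):

load data local inpath '/tmp/score.txt' into table test_score;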

2> Implementation code

package com.hnxy.function;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;

public class MySumAvgUDAF extends AbstractGenericUDAFResolver {
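    // The resolver only validates the argument types at query compile time
    // and returns the evaluator that performs the actual aggregation.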
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        if (info.length != 1) {
            throw new UDFArgumentException("Exactly one argument is expected!");
        }

        if (!info[0].getCategory().equals(ObjectInspector.Category.PRIMITIVE)) {
            throw new UDFArgumentException("The argument must be a PRIMITIVE type!");
        }

        // Hive's bigint column type maps to PrimitiveCategory.LONG
        PrimitiveTypeInfo p = (PrimitiveTypeInfo) info[0];
        if (!p.getPrimitiveCategory().equals(PrimitiveObjectInspector.PrimitiveCategory.LONG)) {
            throw new UDFArgumentException("The argument must be a bigint (long)!");
        }
        return new MyUDAFEvaluator();
    }

    private static class MyUDAFEvaluator extends GenericUDAFEvaluator{

        private Object[] outKey = {new LongWritable(),new LongWritable()};
        private Text outValue = new Text();

        private static class MyAgg implements AggregationBuffer{
            private Long sum = 0L;
            private Long count = 0L;

            public Long getSum() {
                return sum;
            }
            public void setSum(Long sum) {
                this.sum = sum;
            }
            public Long getCount() {
                return count;
            }
            public void setCount(Long count) {
                this.count = count;
            }
        }
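
        // Hive drives the evaluator in one of four modes:
        //   PARTIAL1: raw rows -> partial aggregation (iterate + terminatePartial), map side
        //   PARTIAL2: partial  -> partial             (merge + terminatePartial), combine stage
        //   FINAL:    partial  -> final result        (merge + terminate), reduce side
        //   COMPLETE: raw rows -> final result        (iterate + terminate), map-only
        // init() must therefore return the (sum, count) struct OI for the two
        // partial-output modes and the final string OI otherwise.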

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);

            if(m.equals(Mode.PARTIAL1) || m.equals(Mode.PARTIAL2)){
                List<String> structFieldNames = new ArrayList<>();
                List<ObjectInspector> structFieldTypes = new ArrayList<>();
                structFieldNames.add("sum");
                structFieldNames.add("count");

                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);

                return ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames,structFieldTypes);
            }
            return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new MyAgg();
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            MyAgg ag = (MyAgg)agg;
            ag.setSum(0L);
            ag.setCount(0L);
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            // Skip NULL rows so they affect neither the sum nor the count
            if (parameters[0] == null) {
                return;
            }
            MyAgg ag = (MyAgg) agg;
            ag.setSum(ag.getSum() + Long.parseLong(parameters[0].toString()));
            ag.setCount(ag.getCount() + 1);
            printStr("iterate" + "----->" + ag.getSum() + " : " + ag.getCount());
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            MyAgg ag = (MyAgg) agg;
            printStr("terminatePartial" + "----->" + ag.getSum() + " : " + ag.getCount());
            // Ship the partial (sum, count) pair to the next stage as a struct
            ((LongWritable) outKey[0]).set(ag.getSum());
            ((LongWritable) outKey[1]).set(ag.getCount());
            return outKey;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            MyAgg ag = (MyAgg) agg;
            printStr("merge" + "----->" + ag.getSum() + " : " + ag.getCount());

            // Between stages, the struct emitted by terminatePartial arrives
            // here serialized as a LazyBinaryStruct. (A more robust version
            // would read the fields through the struct ObjectInspector that
            // init() receives in PARTIAL2/FINAL mode.)
            if (partial instanceof LazyBinaryStruct) {
                LazyBinaryStruct ls = (LazyBinaryStruct) partial;
                LongWritable sum = (LongWritable) ls.getField(0);
                LongWritable count = (LongWritable) ls.getField(1);
                // Accumulate into the buffer instead of overwriting it;
                // otherwise only the last merged partial would survive.
                ag.setSum(ag.getSum() + sum.get());
                ag.setCount(ag.getCount() + count.get());
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            MyAgg ag = (MyAgg) agg;
            if (ag.getCount() == 0) {
                return null; // no non-NULL rows were aggregated
            }
            // Divide as doubles: integer division (long sum / long count)
            // would silently truncate 19.8 down to 19
            double d = ag.getSum().doubleValue() / ag.getCount();
            DecimalFormat decimalFormat = new DecimalFormat("###,###.0");
            outValue.set(decimalFormat.format(d));
            return outValue;
        }

        public void printStr(String str){
            System.out.println("-----------------" +str+ "-----------------");
        }
    }
}
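
To sanity-check the arithmetic without a cluster, the evaluator can be driven by hand in COMPLETE mode (raw rows straight to the final result). This is only a minimal local sketch, not how Hive invokes the evaluator in production, where values flow through ObjectInspectors and the partial stages run on different machines; the class name MySumAvgLocalTest is made up for illustration:

package com.hnxy.function;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.LongWritable;

public class MySumAvgLocalTest {
    public static void main(String[] args) throws Exception {
        // Resolve the evaluator exactly as Hive would for a bigint argument
        GenericUDAFEvaluator eval = new MySumAvgUDAF()
                .getEvaluator(new TypeInfo[]{TypeInfoFactory.longTypeInfo});
        ObjectInspector longOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
        eval.init(Mode.COMPLETE, new ObjectInspector[]{longOI});

        AggregationBuffer buf = eval.getNewAggregationBuffer();
        for (long v : new long[]{10, 30, 25, 25, 9}) {
            eval.iterate(buf, new Object[]{new LongWritable(v)});
        }
        // (10 + 30 + 25 + 25 + 9) / 5 = 19.8
        System.out.println(eval.terminate(buf));
    }
}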

This requires a local Hadoop and Hive environment; my setup uses IntelliJ IDEA 2017.3.
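
Before the function can be registered, package the compiled class into a jar and add it to the Hive session. A sketch, assuming the jar was exported as /path/to/myudaf.jar (hypothetical path and name):

add jar /path/to/myudaf.jar;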

Run in Hive:

create temporary function myfunc as 'com.hnxy.function.MySumAvgUDAF';
select myfunc(score) score from test_score;
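
As a cross-check, the built-in aggregate should report the same average:

select avg(score) from test_score;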

Result:

score
19.8

 
