HIVE - UDAF Development (Count Occurrences of a Specified Character in a String, Then Average the Counts)

1. Demonstration
hive> desc test_avg_str_in_str;
user_id             	int
name                	string
value               	int

hive> select * from test_avg_str_in_str;
1	awuz	1
1	azhaoz	1
2	zhangsan	2
2	lisi	2
2	wangwu	3

-- UDAF: avgStr (count the occurrences of "z" in name, then average those counts per group)
-- The tricky part is the average: the intermediate result has to carry both the running sum and the row count, which is where the LazyBinaryStruct structure comes in
hive> select user_id, avgStr(name, "z") from test_avg_str_in_str group by user_id;
1	1.5
2	0.333333

PS. The functionality of this UDAF is something I made up myself; it has no real business application…
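For completeness, the function has to be registered before the query above will run; a typical sequence looks like the following (the jar path and name here are hypothetical, matching however you package the class shown below):

hive> add jar /tmp/avg-str-in-str-udaf.jar;
hive> create temporary function avgStr as 'AvgStrInStrUDAF';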

2. Key Methods


  • PARTIAL1: map phase; calls iterate() and terminatePartial()
  • PARTIAL2: map-side combiner phase; calls merge() and terminatePartial()
  • FINAL: reduce phase; calls merge() and terminate()
  • COMPLETE: map-only aggregation with no reduce; calls iterate() and terminate()
// Decide the input/output ObjectInspectors for each phase
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;
// The class that holds the intermediate aggregation state
abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
// Reset the aggregation state
public void reset(AggregationBuffer agg) throws HiveException;
// Map phase: iterate over the column values passed in from the SQL query
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
// Called when the map or combiner phase ends: return the partial aggregation result
public Object terminatePartial(AggregationBuffer agg) throws HiveException;
// The combiner merges mapper output; the reducer merges mapper or combiner output
public void merge(AggregationBuffer agg, Object partial) throws HiveException;
// Reduce phase: emit the final result
public Object terminate(AggregationBuffer agg) throws HiveException;
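
To make the phase sequence concrete, here is a minimal standalone sketch (hand-run Java, outside Hive) that reproduces the arithmetic the evaluator performs for the user_id = 2 group; occurrences() mirrors the counting expression used in iterate() in the code below:

public class AvgStrTrace {
    // Same trick as iterate(): strip every occurrence of the target substring,
    // then divide the removed length by the target's length
    static long occurrences(String s, String target) {
        return (s.length() - s.replace(target, "").length()) / target.length();
    }

    public static void main(String[] args) {
        String[] names = {"zhangsan", "lisi", "wangwu"}; // the user_id = 2 rows
        long sum = 0, count = 0;
        for (String name : names) {         // iterate(): once per input row
            sum += occurrences(name, "z");  // 1 + 0 + 0
            count++;
        }
        // terminatePartial() would ship {sum=1, count=3} as a struct;
        // terminate() then divides: 1 / 3.0 = 0.333333...
        System.out.println((double) sum / count);
    }
}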
3. Code
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import java.util.ArrayList;
import java.util.List;


public class AvgStrInStrUDAF extends AbstractGenericUDAFResolver {
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 2) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly two arguments are expected.");
        }
        return new AvgCharInStringEvaluator();
    }

    public static class AvgCharInStringEvaluator extends GenericUDAFEvaluator {

        // Reused two-field output object for terminatePartial(): {sum, count}
        private Object[] outKey = {new LongWritable(), new LongWritable()};
        private DoubleWritable result;

        private static class AvgAgg implements AggregationBuffer {
            private Long sum = 0L;
            private Long count = 0L;
            public Long getSum() {
                return sum;
            }
            public void setSum(Long sum) {
                this.sum = sum;
            }
            public Long getCount() {
                return count;
            }
            public void setCount(Long count) {
                this.count = count;
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new AvgAgg();
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AvgAgg ag = (AvgAgg) agg;
            ag.setSum(0L);
            ag.setCount(0L);
        }

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            super.init(m, parameters);
            // PARTIAL1/PARTIAL2 return the intermediate result: struct<sum:bigint, count:bigint>
            if (m.equals(Mode.PARTIAL1) || m.equals(Mode.PARTIAL2)) {
                List<String> structFieldNames = new ArrayList<>();
                List<ObjectInspector> structFieldTypes = new ArrayList<>();
                structFieldNames.add("sum");
                structFieldNames.add("count");

                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);

                return ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames, structFieldTypes);
            }
            // FINAL/COMPLETE return the final result: a double
            result = new DoubleWritable(0.0);
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters == null) {
                return;
            }
            if (parameters[0] != null && parameters[1] != null) {
                String s1 = parameters[0].toString();
                String s2 = parameters[1].toString();
                // Count non-overlapping occurrences of s2 in s1: strip every
                // occurrence, then divide the removed length by s2's length
                long count = (s1.length() - s1.replace(s2, "").length()) / s2.length();
                AvgAgg ag = (AvgAgg) agg;
                ag.setSum(ag.getSum() + count);
                ag.setCount(ag.getCount() + 1);
            }
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AvgAgg ag = (AvgAgg) agg;
            ((LongWritable) outKey[0]).set(ag.getSum());
            ((LongWritable) outKey[1]).set(ag.getCount());
            return outKey;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial != null) {
                AvgAgg ag = (AvgAgg) agg;
                if (partial instanceof LazyBinaryStruct) {
                    LazyBinaryStruct ls = (LazyBinaryStruct) partial;
                    // Accumulate rather than overwrite: merge() runs once per
                    // partial result coming from each mapper/combiner
                    ag.setSum(ag.getSum() + ((LongWritable) ls.getField(0)).get());
                    ag.setCount(ag.getCount() + ((LongWritable) ls.getField(1)).get());
                }
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AvgAgg ag = (AvgAgg) agg;
            if (ag.getCount() == 0) {
                return null; // nothing aggregated; avoid division by zero
            }
            result.set((double) ag.getSum() / ag.getCount());
            return result;
        }
    }
}
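
One caveat about merge(): the instanceof LazyBinaryStruct branch covers the usual MapReduce serialization path, but Hive may hand the partial result to merge() in other struct representations. A more defensive variant, sketched below on the assumption that init() stores the struct ObjectInspector it receives (parameters[0]) when the mode is PARTIAL2 or FINAL, lets the inspector extract the fields instead of casting the container:

// Additional imports for this variant:
// import org.apache.hadoop.hive.serde2.objectinspector.StructField;
// import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

// Field on the evaluator, assigned in init() for PARTIAL2/FINAL:
//     partialOI = (StructObjectInspector) parameters[0];
private StructObjectInspector partialOI;

@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
    if (partial == null) {
        return;
    }
    AvgAgg ag = (AvgAgg) agg;
    StructField sumField = partialOI.getStructFieldRef("sum");
    StructField countField = partialOI.getStructFieldRef("count");
    // getStructFieldData() works for LazyBinaryStruct, Object[], and any
    // other struct layout the framework chooses
    ag.setSum(ag.getSum() + ((LongWritable) partialOI.getStructFieldData(partial, sumField)).get());
    ag.setCount(ag.getCount() + ((LongWritable) partialOI.getStructFieldData(partial, countField)).get());
}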

