一、實現展示
hive> desc test_avg_str_in_str;
user_id int
name string
value int
hive> select * from test_avg_str_in_str;
1 awuz 1
1 azhaoz 1
2 zhangsan 2
2 lisi 2
2 wangwu 3
-- UDAF: avgStr (找到name中出現z的次數,再求平均數)
-- 難點在計算平均數的時候,中間結果需要保存 總值和計數值,需要用到 LazyBinaryStruct 結構
hive> select user_id, avgStr(name, "z") from test_avg_str_in_str group by user_id;
1 1.5
2 0.333333
PS. 這個UDAF實現的功能是自己構想的練習示例,目前沒有對應的實際業務應用。
二、關鍵函數
- PARTIAL1: map階段, 調用iterate()和terminatePartial()
- PARTIAL2: map端的Combiner階段,調用merge() 和 terminatePartial()
- FINAL: reduce階段,調用merge()和terminate()
- COMPLETE: map-only(無reduce)階段,調用iterate()和terminate()
// 確定各個階段輸入輸出參數的數據格式ObjectInspectors
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;
// 保存數據聚集結果的類
abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
// 重置聚集結果
public void reset(AggregationBuffer agg) throws HiveException;
// map階段,迭代處理輸入sql傳過來的列數據
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
// map與combiner結束返回結果,得到部分數據聚集結果
public Object terminatePartial(AggregationBuffer agg) throws HiveException;
// combiner合併map返回的結果,還有reducer合併mapper或combiner返回的結果。
public void merge(AggregationBuffer agg, Object partial) throws HiveException;
// reducer階段,輸出最終結果
public Object terminate(AggregationBuffer agg) throws HiveException;
三、代碼CODE
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import java.util.ArrayList;
import java.util.List;
public class AvgStrInStrUDAF extends AbstractGenericUDAFResolver {

    /**
     * Resolver entry point: validates that the UDAF is invoked with exactly
     * two arguments (the string column and the substring to count), then
     * returns the evaluator.
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 2) {
            // Fixed message typo ("Exactly tow argument is expected.").
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly two arguments are expected.");
        }
        return new AvgCharInStringEvaluator();
    }

    public static class AvgCharInStringEvaluator extends GenericUDAFEvaluator {

        // Reused holder for terminatePartial() output: struct {sum, count}.
        private final Object[] partialResult = {new LongWritable(), new LongWritable()};
        // Reused holder for the final double result (created in init()).
        private DoubleWritable result;

        /** Aggregation buffer: running total of substring matches and row count. */
        private static class AvgAgg implements AggregationBuffer {
            private long sum = 0L;
            private long count = 0L;
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new AvgAgg();
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AvgAgg buf = (AvgAgg) agg;
            buf.sum = 0L;
            buf.count = 0L;
        }

        /**
         * Declares the output ObjectInspector per mode:
         * - PARTIAL1/PARTIAL2 emit the intermediate struct {sum: bigint, count: bigint};
         * - FINAL/COMPLETE emit the final double average.
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            super.init(m, parameters);
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                List<String> fieldNames = new ArrayList<>();
                List<ObjectInspector> fieldOIs = new ArrayList<>();
                fieldNames.add("sum");
                fieldNames.add("count");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
            }
            result = new DoubleWritable(0.0);
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /**
         * Map-side accumulation: counts non-overlapping occurrences of
         * parameters[1] inside parameters[0] and adds them to the buffer.
         * NULL rows are skipped and do not contribute to the row count.
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters == null || parameters[0] == null || parameters[1] == null) {
                return;
            }
            String text = parameters[0].toString();
            String needle = parameters[1].toString();
            if (needle.isEmpty()) {
                // Guard: the replace-based count below would divide by zero.
                return;
            }
            long occurrences = (text.length() - text.replace(needle, "").length()) / needle.length();
            AvgAgg buf = (AvgAgg) agg;
            buf.sum += occurrences;
            buf.count += 1;
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AvgAgg buf = (AvgAgg) agg;
            ((LongWritable) partialResult[0]).set(buf.sum);
            ((LongWritable) partialResult[1]).set(buf.count);
            return partialResult;
        }

        /**
         * Merges one partial {sum, count} struct into the buffer. This must
         * ACCUMULATE: the original implementation overwrote the buffer, losing
         * all but the last partial whenever a combiner/reducer merged several
         * partials for the same group. Also handles the in-process Object[]
         * form of the partial (map-side hash aggregation), which the original
         * silently ignored.
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            AvgAgg buf = (AvgAgg) agg;
            if (partial instanceof LazyBinaryStruct) {
                // Deserialized partial coming from another task.
                LazyBinaryStruct struct = (LazyBinaryStruct) partial;
                buf.sum += ((LongWritable) struct.getField(0)).get();
                buf.count += ((LongWritable) struct.getField(1)).get();
            } else if (partial instanceof Object[]) {
                // Raw Object[] produced by terminatePartial() in the same JVM.
                Object[] fields = (Object[]) partial;
                buf.sum += ((LongWritable) fields[0]).get();
                buf.count += ((LongWritable) fields[1]).get();
            }
        }

        /** Final result: sum / count, or NULL when no rows were aggregated. */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AvgAgg buf = (AvgAgg) agg;
            if (buf.count == 0) {
                // SQL AVG over zero rows is NULL; also avoids division by zero.
                return null;
            }
            result.set((double) buf.sum / buf.count);
            return result;
        }
    }
}
參考文章
https://blog.csdn.net/kent7306/article/details/50110067
https://blog.csdn.net/Nougats/article/details/71978752
https://www.jianshu.com/p/7ebc8f9c9b78