package com.facebook.hive.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.BooleanWritable;

import java.util.ArrayList;
import java.util.List;

/**
 * Aggregate the values which satisfy the condition into an array. This
 * function is similar to COLLECT, except that it takes a second argument;
 * only rows whose condition evaluates to true are collected, and the rest
 * are ignored. Rows with a NULL value or a NULL condition are also ignored.
 * Like COLLECT, you may need to turn off map-side aggregation
 * (SET hive.map.aggr = false) lest you exhaust the heap.
 */
@Description(name = "collect_where",
    value = "_FUNC_(value, condition) - aggregate the values which satisfy the condition into an array")
public class UDAFCollectWhere extends AbstractGenericUDAFResolver {

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 2) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly two arguments are expected: a value and a boolean condition.");
    }
    return new Evaluator();
  }

  /** The aggregation buffer: the values collected so far for one group. */
  public static class State implements AggregationBuffer {
    ArrayList<Object> elements = new ArrayList<Object>();
  }

  public static class Evaluator extends GenericUDAFEvaluator {
    // Object inspector for the value argument (original rows).
    ObjectInspector inputOI;
    // Object inspector for the list of partially aggregated values.
    ListObjectInspector internalMergeOI;
    // Object inspector for the boolean condition argument.
    ObjectInspector conditionOI;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);
      if (m == Mode.COMPLETE || m == Mode.PARTIAL1) {
        // Consuming original rows: parameters are (value, condition).
        inputOI = parameters[0];
        conditionOI = parameters[1];
        return ObjectInspectorFactory.getStandardListObjectInspector(
            ObjectInspectorUtils.getStandardObjectInspector(inputOI));
      } else {
        // PARTIAL2 or FINAL: the single parameter is a list of partial results.
        internalMergeOI = (ListObjectInspector) parameters[0];
        return ObjectInspectorUtils.getStandardObjectInspector(parameters[0]);
      }
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      return new State();
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] input) throws HiveException {
      // Skip rows where either the value or the condition is NULL.
      if (input[0] != null && input[1] != null) {
        // Request a WRITABLE copy so the cast is safe even when the condition
        // comes from an object inspector that prefers Java primitives.
        BooleanWritable condition = (BooleanWritable) ObjectInspectorUtils.copyToStandardObject(
            input[1], conditionOI, ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
        if (condition.get()) {
          State state = (State) agg;
          // Copy the value out of its (possibly lazily deserialized) container
          // so it stays valid after Hive reuses the row object.
          state.elements.add(ObjectInspectorUtils.copyToStandardObject(input[0], inputOI));
        }
      }
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        State state = (State) agg;
        state.elements.addAll(
            (List<?>) ObjectInspectorUtils.copyToStandardObject(partial, internalMergeOI));
      }
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((State) agg).elements.clear();
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      return ((State) agg).elements;
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      return ((State) agg).elements;
    }
  }
}
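/*
 * Minimal usage sketch. The jar, table, and column names below are
 * hypothetical; only the class name and the hive.map.aggr setting are taken
 * from this file. The query collects, per user, the URLs of rows whose
 * is_error flag is true:
 *
 *   ADD JAR collect_where.jar;   -- hypothetical jar name
 *   CREATE TEMPORARY FUNCTION collect_where
 *     AS 'com.facebook.hive.udf.UDAFCollectWhere';
 *   SET hive.map.aggr = false;   -- see the class Javadoc on heap usage
 *
 *   SELECT user_id, collect_where(url, is_error)
 *   FROM page_views
 *   GROUP BY user_id;
 */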