package com.facebook.hive.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/**
 * Aggregates the values of an array column into a single array of distinct
 * elements. It is analogous to collect_set(col), which does the same thing for
 * columns of a primitive type. Because map-side aggregation keeps all of the
 * data in memory, set hive.map.aggr=false when your data is non-trivially
 * large so that UNION_SET is executed only in the reduce phase.
 *
 * @author cbueno
 */
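// A minimal usage sketch in HiveQL; the jar path, table, and column names below
// are hypothetical placeholders, not part of this repository:
//
//   ADD JAR /path/to/facebook-hive-udfs.jar;
//   CREATE TEMPORARY FUNCTION union_set AS 'com.facebook.hive.udf.UDAFUnionSet';
//   SET hive.map.aggr=false;
//   SELECT user_id, union_set(tags) AS all_tags
//   FROM page_tags
//   GROUP BY user_id;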
// The unit test result should ideally be something like:
// ["3.4","3.5","3.6","3.7","3.8","4.4","3.9","4.2","4.1","3.1","3.2","3.3","2.2","2.3","2.4","2","4","3","2.9","2.5","2.6","2.7","2.8"]
// However, because a HashSet returns its elements in no defined order, the output
// cannot be guaranteed to match that static string.
@Description(
    name = "union_set",
    value = "_FUNC_(col) - aggregates the values of an array column into one array",
    extended = "Aggregates the distinct values across all rows and returns them as an ArrayList.")
public class UDAFUnionSet extends AbstractGenericUDAFResolver {
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    // iterate() casts each input value to a List, so the argument must be a single array-typed column.
    if (parameters.length != 1 || parameters[0].getCategory() != ObjectInspector.Category.LIST) {
      throw new UDFArgumentTypeException(0, "union_set expects exactly one array-typed argument.");
    }
    return new Evaluator();
  }
public static class State implements AggregationBuffer {
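    // Distinct elements accumulated so far.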
HashSet<Object> set = new HashSet<>();
}
public static class Evaluator extends GenericUDAFEvaluator {
    // ObjectInspector for the raw array column (PARTIAL1 / COMPLETE modes).
    ObjectInspector inputOI;
    // ObjectInspector for the partial result list (PARTIAL2 / FINAL modes).
    ListObjectInspector internalMergeOI;
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
super.init(m, parameters);
if (m == Mode.COMPLETE || m == Mode.PARTIAL1) {
inputOI = parameters[0];
} else {
internalMergeOI = (ListObjectInspector) parameters[0];
}
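      // Both the partial and the final results are standard Java lists of the same
      // element type as the input, so the standard ObjectInspector of parameters[0]
      // fits every mode.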
return ObjectInspectorUtils.getStandardObjectInspector(parameters[0]);
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
return new State();
}
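    // Map-side (PARTIAL1) or single-stage (COMPLETE) processing: copy the incoming
    // array to standard Java objects and add its elements to the distinct set.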
@Override
public void iterate(AggregationBuffer agg, Object[] input) throws HiveException {
if (input[0] != null) {
State state = (State) agg;
state.set.addAll((List<?>)ObjectInspectorUtils.copyToStandardObject(input[0], inputOI));
}
}
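    // Reduce-side merging (PARTIAL2 / FINAL): fold a partial list from another task
    // into this buffer's distinct set.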
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial != null) {
State state = (State) agg;
List<?> pset = (List<?>)ObjectInspectorUtils.copyToStandardObject(partial, internalMergeOI);
state.set.addAll(pset);
}
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
((State) agg).set.clear();
}
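    // Final result: the distinct elements as a list (iteration order of the set is undefined).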
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
return new ArrayList<>(((State) agg).set);
}
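    // Partial result shipped between map and reduce stages; same shape as the final result.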
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
return new ArrayList<>(((State) agg).set);
}
}
}