package brickhouse.udf.collect;

import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

import java.util.HashMap;
import java.util.Map;

import static org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardMapObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.getStandardObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;

/**
 * Base resolver for "collect and merge" map-building UDAFs.
 *
 * <p>The aggregate takes two arguments (a key and a primitive value), builds a
 * {@code map<key, value>}, and resolves key collisions via
 * {@link MergeAggBuffer#mergeValues(Object, Object)}. Subclasses declare which
 * primitive value categories they support by returning an evaluator class per
 * category from {@link #evaluators()}.
 */
public abstract class AbstractCollectMergeUDAF extends AbstractGenericUDAFResolver {

    /**
     * @return a mapping from supported primitive value categories to the evaluator
     *         class that handles each one
     */
    public abstract Map<PrimitiveCategory, Class<? extends CollectMergeUDAFEvaluator>> evaluators();

    /**
     * Instantiates the evaluator registered for the given value category.
     *
     * @param valueCategory primitive category of the UDAF's 2nd argument
     * @return a new evaluator instance, or {@code null} if the category is not
     *         supported (callers such as {@link #getEvaluator(TypeInfo[])} rely on
     *         the null return to report an unsupported type)
     */
    public CollectMergeUDAFEvaluator newEvaluator(PrimitiveCategory valueCategory) {
        Class<? extends CollectMergeUDAFEvaluator> evaluatorClass = evaluators().get(valueCategory);
        if (evaluatorClass == null) {
            // Previously this fell through to evaluatorClass.newInstance() and threw
            // an NPE, making getEvaluator's "unsupported type" diagnostic unreachable.
            return null;
        }
        try {
            return evaluatorClass.newInstance();
        } catch (InstantiationException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /** @return a comma-separated list of supported value categories (empty if none). */
    private String supportedTypes() {
        StringBuilder sb = new StringBuilder();
        for (PrimitiveCategory category : evaluators().keySet()) {
            if (sb.length() > 0) {
                sb.append(", ");
            }
            sb.append(category.name());
        }
        // Note: building with a separator guard avoids the
        // StringIndexOutOfBoundsException the old substring(2) threw on an empty map.
        return sb.toString();
    }

    /**
     * Resolves the evaluator for the given argument types.
     *
     * @param parameters the UDAF call's argument types; must be exactly (key, primitive value)
     * @throws SemanticException if the arity is wrong, the 2nd argument is not
     *         primitive, or its primitive category is unsupported
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 2) {
            throw new UDFArgumentTypeException(1, "Expected 2 arguments");
        }
        if (parameters[1] instanceof PrimitiveTypeInfo) {
            PrimitiveTypeInfo valueInfo = (PrimitiveTypeInfo) parameters[1];
            CollectMergeUDAFEvaluator evaluator = newEvaluator(valueInfo.getPrimitiveCategory());
            if (evaluator != null) {
                return evaluator;
            } else {
                throw new UDFArgumentTypeException(1,
                        "Only " + supportedTypes() + " types are supported for the 2nd argument");
            }
        } else {
            throw new UDFArgumentTypeException(1, "2nd argument must be primitive");
        }
    }

    /**
     * Generic evaluator that accumulates (key, value) pairs into a map, merging
     * values for duplicate keys.
     *
     * @param <E> Java type of the primitive value being merged
     */
    public static abstract class CollectMergeUDAFEvaluator<E> extends GenericUDAFEvaluator {

        protected ObjectInspector keyOI;
        protected PrimitiveObjectInspector valueOI;
        // Inspector for the partial-aggregation map produced by terminatePartial.
        protected StandardMapObjectInspector internalMergeOI;

        /**
         * Wires up inspectors for the current mode. PARTIAL1/COMPLETE see the raw
         * (key, value) arguments; PARTIAL2/FINAL see the intermediate map.
         * The output (partial and final) is always a standard Java map.
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                keyOI = parameters[0];
                valueOI = (PrimitiveObjectInspector) parameters[1];
            } else {
                internalMergeOI = (StandardMapObjectInspector) parameters[0];
                keyOI = internalMergeOI.getMapKeyObjectInspector();
                valueOI = (PrimitiveObjectInspector) internalMergeOI.getMapValueObjectInspector();
            }
            return getStandardMapObjectInspector(
                    getStandardObjectInspector(keyOI),
                    getStandardObjectInspector(valueOI, JAVA)
            );
        }

        @Override
        public abstract MergeAggBuffer<E> getNewAggregationBuffer() throws HiveException;

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((MergeAggBuffer) agg).reset();
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] args) throws HiveException {
            merge((MergeAggBuffer) agg, args[0], args[1]);
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            // The partial result has the same shape as the final one.
            return terminate(agg);
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            MergeAggBuffer myAgg = (MergeAggBuffer) agg;
            for (Map.Entry<?, ?> e : internalMergeOI.getMap(partial).entrySet()) {
                merge(myAgg, e.getKey(), e.getValue());
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            // Copy so Hive's reuse of the buffer cannot mutate the returned map.
            return ((MergeAggBuffer) agg).copy();
        }

        /** Copies the key to a standard object and merges the primitive value into the buffer. */
        @SuppressWarnings("unchecked")
        private void merge(MergeAggBuffer agg, Object key, Object value) {
            Object keyCopy = ObjectInspectorUtils.copyToStandardObject(key, keyOI);
            E primValue = (E) valueOI.getPrimitiveJavaObject(value);
            ((MergeAggBuffer<E>) agg).merge(keyCopy, primValue);
        }
    }

    /**
     * Aggregation buffer contract: a resettable key/value store with a
     * value-merging policy for duplicate keys.
     *
     * @param <V> value type
     */
    public interface MergeAggBuffer<V> extends AggregationBuffer {
        /** Clears all accumulated entries. */
        void reset();

        /** Combines two non-null values mapped to the same key. */
        V mergeValues(V left, V right);

        /** Adds a (key, value) pair, merging with any existing value for the key. */
        void merge(Object key, V value);

        /** @return an independent snapshot of the accumulated map */
        Map<Object, V> copy();
    }

    /**
     * {@link HashMap}-backed buffer. Subclasses only supply
     * {@link #mergeValues(Object, Object)}.
     *
     * @param <V> value type
     */
    public static abstract class HashMapMergeAggBuffer<V> extends HashMap<Object, V>
            implements MergeAggBuffer<V> {

        @Override
        public void reset() {
            clear();
        }

        @Override
        public void merge(Object key, V value) {
            V oldValue = get(key);
            if (oldValue != null) {
                // don't merge with null
                // this method should be overridden in subclasses if special treatment of nulls is needed
                if (value != null) {
                    V newValue = mergeValues(oldValue, value);
                    if (!oldValue.equals(newValue)) {
                        put(key, newValue);
                    }
                }
            } else {
                // if the key is absent from the map or the previous value is null
                put(key, value);
            }
        }

        @Override
        public Map<Object, V> copy() {
            return new HashMap<Object, V>(this);
        }
    }
}