package org.notmysock.hive; import hyperloglog.HyperLogLog; import hyperloglog.HyperLogLog.EncodingType; import hyperloglog.HyperLogLogUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver2; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBinaryObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.BytesWritable; @Description(name = "hyperloglog", value = "_FUNC_(x) - generates HyperLogLog structures from input column", extended = "hive> create dau_table as select hyperloglog(uid) as hll_uid from traffic group by dt") public class UDAFHyperLogLog implements GenericUDAFResolver2 { static final class HyperLogLogBuffer extends AbstractAggregationBuffer { public HyperLogLog hll; public HyperLogLogBuffer() { this.reset(); } @Override public int estimate() { return 16*1024; /* 16kb usually */ } public void reset() { hll = HyperLogLog.builder() .setNumRegisterIndexBits(15).setNumHashBits(64).setEncoding(EncodingType.SPARSE).build(); } } public static class HyperLogLogEvaluator extends GenericUDAFEvaluator { ObjectInspector inputOI; WritableBinaryObjectInspector partialOI; ByteArrayOutputStream output = new ByteArrayOutputStream(); /* * All modes returns BINARY columns. * * PARTIAL1 takes in a primitive inspector * * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator#init(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode, org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector[]) */ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { super.init(m, parameters); partialOI = PrimitiveObjectInspectorFactory.writableBinaryObjectInspector; switch (m) { case PARTIAL1: inputOI = parameters[0]; case PARTIAL2: return partialOI; case FINAL: case COMPLETE: return partialOI; default: throw new IllegalArgumentException("Unknown UDAF mode " + m); } } @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { return new HyperLogLogBuffer(); } @Override public void iterate(AggregationBuffer agg, Object[] args) throws HiveException { if (args[0] == null) { return; } HyperLogLog hll = ((HyperLogLogBuffer)agg).hll; // should use BinarySortableSerDe, perhaps Object val = ObjectInspectorUtils.copyToStandardJavaObject(args[0], inputOI); try { if (val instanceof Byte || val instanceof Character || val instanceof Short) { hll.add(val.hashCode()); } else if (val instanceof Integer) { hll.addInt(((Integer) val).intValue()); } else if(val instanceof Long) { hll.addLong(((Long) val).longValue()); } else if (val instanceof Float) { hll.addFloat(((Float) val).floatValue()); } else if (val instanceof Double) { hll.addDouble((Double)val); } else if (val instanceof String) { hll.addString(val.toString()); } else { /* potential multi-key option */ output.reset(); ObjectOutputStream out = new ObjectOutputStream(output); out.writeObject(val); byte[] key = output.toByteArray(); hll.addBytes(key); } } catch(IOException ioe) { throw new HiveException(ioe); } } @Override public Object terminatePartial(AggregationBuffer agg) throws HiveException { HyperLogLog hll = ((HyperLogLogBuffer)agg).hll; output.reset(); try { HyperLogLogUtils.serializeHLL(output, hll); } catch(IOException ioe) { throw new HiveException(ioe); } return new BytesWritable(output.toByteArray()); } @Override public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial == null) { return; } final BytesWritable bw = partialOI.getPrimitiveWritableObject(partial); HyperLogLog hll = ((HyperLogLogBuffer)agg).hll; merge(hll, bw); } protected void merge(HyperLogLog hll, BytesWritable bw) throws HiveException { try { ByteArrayInputStream input = new ByteArrayInputStream(bw.getBytes(), 0, bw.getLength()); HyperLogLog hll2 = HyperLogLogUtils.deserializeHLL(input); hll.merge(hll2); input.close(); } catch (IOException ioe) { throw new HiveException(ioe); } } @Override public void reset(AggregationBuffer agg) throws HiveException { ((HyperLogLogBuffer)agg).reset(); } @Override public Object terminate(AggregationBuffer agg) throws HiveException { HyperLogLog hll = ((HyperLogLogBuffer)agg).hll; output.reset(); try { HyperLogLogUtils.serializeHLL(output, hll); } catch(IOException ioe) { throw new HiveException(ioe); } return new BytesWritable(output.toByteArray()); } } @Override public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException { return getEvaluator(info.getParameters()); } @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { if (parameters.length != 1) { throw new IllegalArgumentException("Function only takes 1 parameter"); } else if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE && parameters[0].getCategory() != ObjectInspector.Category.STRUCT) { throw new UDFArgumentTypeException(1, "Only primitive/struct rows are accepted but " + parameters[0].getTypeName() + " was passed."); } return new HyperLogLogEvaluator(); } }