package brickhouse.udf.timeseries;
/**
 * Copyright 2012 Klout, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 **/

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Aggregate that adds numeric "vectors" together across rows.
 *
 * <p>A vector is either an array of numbers, summed position by position, or a
 * map with primitive keys and numeric values, summed key by key. The result is
 * always expressed in doubles: an array of doubles for array input, and a map
 * from the (standardized) key to double for map input.
 *
 * <p>Numeric conversion and numeric-type checks are delegated to
 * {@code NumericUtil}, a helper in this package that is not shown here.
 */
@Description(name = "union_vector_sum",
        value = "_FUNC_(x) - Aggregate adding vectors together "
)
public class VectorUnionSumUDAF extends AbstractGenericUDAFResolver {

    /**
     * Chooses the evaluator matching the category of the single argument.
     *
     * @param parameters type information for the call's arguments
     * @return an array evaluator for LIST arguments, a map evaluator for MAP arguments
     * @throws SemanticException (as {@link UDFArgumentTypeException}) when the argument count
     *         is not exactly one, or the argument is neither a list nor a map
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Vector sum takes one argument");
        }

        Category category = parameters[0].getCategory();
        if (category == Category.LIST) {
            return new VectorArraySumUDAFEvaluator();
        } else if (category == Category.MAP) {
            return new VectorMapSumUDAFEvaluator();
        } else {
            throw new UDFArgumentTypeException(0,
                    " vector_union_sum aggregates either arrays or maps");
        }
    }

    /**
     * Sums arrays element by element. Input arrays may differ in length; the
     * running sum grows to the length of the longest array seen so far.
     */
    public static class VectorArraySumUDAFEvaluator extends GenericUDAFEvaluator {
        // Inspector for whatever this evaluator is fed; init() assigns it from
        // parameters[0] in every mode, and both iterate() and merge() read through it.
        private ListObjectInspector inputOI;
        // Inspector describing the value handed back by terminate()/terminatePartial():
        // a standard list of Java doubles.
        private StandardListObjectInspector stdListOI;

        /** Holds the running element-wise sum. */
        static class VectorArrayAggBuffer implements AggregationBuffer {
            ArrayList<Double> sumArray = new ArrayList<Double>();
        }

        /**
         * Validates that the incoming list has numeric primitive elements and
         * declares the output type.
         *
         * @param m          evaluation mode, forwarded to the superclass
         * @param parameters inspectors for the incoming values; only index 0 is used
         * @return a standard list-of-double inspector
         * @throws HiveException when the list elements are not numeric primitives
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            super.init(m, parameters);

            inputOI = (ListObjectInspector) parameters[0];
            ObjectInspector elementOI = inputOI.getListElementObjectInspector();
            if (elementOI.getCategory() != Category.PRIMITIVE
                    || !NumericUtil.isNumericCategory(
                            ((PrimitiveObjectInspector) elementOI).getPrimitiveCategory())) {
                throw new HiveException("Vector values must be numeric.");
            }

            // Always return the standard list of doubles, whatever the input element type.
            stdListOI = ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
            return stdListOI;
        }

        /** Creates an empty aggregation buffer. */
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AggregationBuffer buff = new VectorArrayAggBuffer();
            reset(buff);
            return buff;
        }

        /** Adds one input array to the running sum; null arrays are ignored. */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            Object p = parameters[0];
            if (p != null) {
                addVector(p, (VectorArrayAggBuffer) agg, inputOI);
            }
        }

        /** Adds a partial sum to the running sum; null partials are ignored, as in iterate(). */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                addVector(partial, (VectorArrayAggBuffer) agg, inputOI);
            }
        }

        /** Discards any accumulated sum. */
        @Override
        public void reset(AggregationBuffer buff) throws HiveException {
            VectorArrayAggBuffer arrayBuff = (VectorArrayAggBuffer) buff;
            arrayBuff.sumArray = new ArrayList<Double>();
        }

        /** Returns the accumulated sum (the buffer's own list, not a copy). */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            VectorArrayAggBuffer myagg = (VectorArrayAggBuffer) agg;
            return myagg.sumArray;
        }

        /**
         * Adds each element of {@code listObj} to the same position of the running sum.
         *
         * <p>Positions the running sum does not have yet are appended. The previous code
         * called {@code get(i)} after only {@code ensureCapacity}, which reserves storage
         * but leaves {@code size()} unchanged, so the very first vector added to an empty
         * buffer failed with an IndexOutOfBoundsException.
         *
         * @param listObj the list value to add, read through {@code listOI}
         * @param myagg   the buffer holding the running sum
         * @param listOI  inspector used to read {@code listObj} and its elements
         */
        private void addVector(Object listObj, VectorArrayAggBuffer myagg,
                               ListObjectInspector listOI) {
            int listLen = listOI.getListLength(listObj);
            if (listLen > myagg.sumArray.size()) {
                // Pre-size the backing storage so the appends below do not regrow it repeatedly.
                myagg.sumArray.ensureCapacity(listLen);
            }

            PrimitiveObjectInspector elementOI =
                    (PrimitiveObjectInspector) listOI.getListElementObjectInspector();
            for (int i = 0; i < listLen; ++i) {
                Object listElem = listOI.getListElement(listObj, i);
                double listElemDbl = NumericUtil.getNumericValue(elementOI, listElem);

                if (i < myagg.sumArray.size()) {
                    Double oldVal = myagg.sumArray.get(i);
                    myagg.sumArray.set(i, oldVal != null ? oldVal + listElemDbl : listElemDbl);
                } else {
                    // i advances by one from zero, so here i == size() and add() lands at index i.
                    myagg.sumArray.add(listElemDbl);
                }
            }
        }

        /** Returns the accumulated sum as the partial result (the buffer's own list). */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            VectorArrayAggBuffer myagg = (VectorArrayAggBuffer) agg;
            return myagg.sumArray;
        }
    }

    /**
     * Sums maps key by key. Keys are copied to their standard Java form before
     * being used as keys of the running sum.
     */
    public static class VectorMapSumUDAFEvaluator extends GenericUDAFEvaluator {
        // Inspector for whatever this evaluator is fed; init() assigns it from
        // parameters[0] in every mode, and both iterate() and merge() read through it.
        private MapObjectInspector inputOI;
        // Inspector describing the value handed back by terminate()/terminatePartial():
        // a standard map from the standardized key type to Java double.
        private StandardMapObjectInspector stdMapOI;

        /** Holds the running per-key sum. */
        static class VectorMapAggBuffer implements AggregationBuffer {
            Map<Object, Double> sumMap = new HashMap<Object, Double>();
        }

        /**
         * Validates that the incoming map has primitive keys and numeric primitive
         * values, and declares the output type.
         *
         * @param m          evaluation mode, forwarded to the superclass
         * @param parameters inspectors for the incoming values; only index 0 is used
         * @return a standard map inspector from the standardized key type to double
         * @throws HiveException when keys are not primitive or values are not numeric primitives
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            super.init(m, parameters);

            inputOI = (MapObjectInspector) parameters[0];

            ObjectInspector keyOI = inputOI.getMapKeyObjectInspector();
            if (keyOI.getCategory() != Category.PRIMITIVE) {
                throw new HiveException("Vector map keys must be a primitive.");
            }

            ObjectInspector valueOI = inputOI.getMapValueObjectInspector();
            if (valueOI.getCategory() != Category.PRIMITIVE
                    || !NumericUtil.isNumericCategory(
                            ((PrimitiveObjectInspector) valueOI).getPrimitiveCategory())) {
                throw new HiveException("Vector values must be numeric.");
            }

            //// XXX make return type numeric type of input,
            //// not doubles...
            stdMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
                    ObjectInspectorUtils.getStandardObjectInspector(keyOI,
                            ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA),
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
            return stdMapOI;
        }

        /** Creates an empty aggregation buffer. */
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AggregationBuffer buff = new VectorMapAggBuffer();
            reset(buff);
            return buff;
        }

        /** Adds one input map to the running sum; null maps are ignored. */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            Object p = parameters[0];
            if (p != null) {
                addVectorMap(p, (VectorMapAggBuffer) agg, inputOI);
            }
        }

        /** Adds a partial sum to the running sum; null partials are ignored, as in iterate(). */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                addVectorMap(partial, (VectorMapAggBuffer) agg, inputOI);
            }
        }

        /** Discards any accumulated sum. */
        @Override
        public void reset(AggregationBuffer buff) throws HiveException {
            VectorMapAggBuffer mapBuff = (VectorMapAggBuffer) buff;
            mapBuff.sumMap = new HashMap<Object, Double>();
        }

        /** Returns the accumulated sum (the buffer's own map, not a copy). */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            VectorMapAggBuffer myagg = (VectorMapAggBuffer) agg;
            return myagg.sumMap;
        }

        /**
         * Adds each entry of {@code mapObj} to the running sum under the entry's
         * standardized key.
         *
         * @param mapObj the map value to add, read through {@code mapOI}
         * @param myagg  the buffer holding the running sum
         * @param mapOI  inspector used to read {@code mapObj}, its keys and its values
         */
        private void addVectorMap(Object mapObj, VectorMapAggBuffer myagg,
                                  MapObjectInspector mapOI) {
            Map<?, ?> uninspMap = mapOI.getMap(mapObj);
            if (uninspMap == null) {
                // Nothing to add; guards against inspectors that yield null for an absent map.
                return;
            }

            ObjectInspector keyOI = mapOI.getMapKeyObjectInspector();
            PrimitiveObjectInspector valueOI =
                    (PrimitiveObjectInspector) mapOI.getMapValueObjectInspector();
            for (Map.Entry<?, ?> entry : uninspMap.entrySet()) {
                Object stdKey = ObjectInspectorUtils.copyToStandardJavaObject(entry.getKey(), keyOI);
                double stdVal = NumericUtil.getNumericValue(valueOI, entry.getValue());

                Double prevVal = myagg.sumMap.get(stdKey);
                myagg.sumMap.put(stdKey, prevVal != null ? prevVal + stdVal : stdVal);
            }
        }

        /** Returns the accumulated sum as the partial result (the buffer's own map). */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            VectorMapAggBuffer myagg = (VectorMapAggBuffer) agg;
            return myagg.sumMap;
        }
    }
}