/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.udf.approx;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount.GenericUDAFCountEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver2;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;

/**
 * This class implements the Approximate COUNT aggregation function as in SQL.
 * Implemented using Closed Forms: N(np, n(1-p)p)
 * We keep track of total elements to calculate p
 *
 * Note: Doesn't work with DISTINCT
 */
@Description(name = "approx_count",
    value = "_FUNC_() - Returns the total number of retrieved rows, including "
        + "rows containing NULL values.\n"
        + "_FUNC_(expr) - Returns the number of rows for which the supplied "
        + "expression is non-NULL.\n")
public class ApproxUDAFCount implements GenericUDAFResolver2 {

  // FIX: was GenericUDAFCount.class — log output was attributed to the wrong class.
  private static final Log LOG = LogFactory.getLog(ApproxUDAFCount.class.getName());

  /**
   * Legacy resolver entry point; falls back to the exact COUNT evaluator.
   * Preserved for backward compatibility with the old resolver interface.
   */
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    return new GenericUDAFCountEvaluator();
  }

  /**
   * Resolver entry point used by current Hive. In addition to the counted
   * expression, the call supplies 2 extra parameters:
   * (1) Table Size for error correction (totalRows)
   * (2) Sample Size for scaling (sampleRows)
   *
   * @throws SemanticException if DISTINCT is requested, which this
   *           approximation cannot support.
   */
  @Override
  public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo paramInfo)
      throws SemanticException {
    // FIX: was a plain `assert`, which is a no-op unless the JVM runs with -ea;
    // DISTINCT would silently yield wrong results. Fail the query at compile time.
    if (paramInfo.isDistinct()) {
      throw new SemanticException("DISTINCT not supported with APPROX COUNT");
    }
    return new ApproxUDAFCountEvaluator().setCountAllColumns(
        paramInfo.isAllColumns());
  }

  /**
   * ApproxUDAFCountEvaluator.
   *
   * Scales the count observed on a sample up to the full table
   * (value * totalRows / sampleRows) and reports a 99% confidence interval
   * from the normal approximation of the binomial distribution.
   */
  public static class ApproxUDAFCountEvaluator extends GenericUDAFEvaluator {
    private boolean countAllColumns = false;

    // Inspectors for the trailing (sampleRows, totalRows) arguments.
    private PrimitiveObjectInspector totalRowsOI;
    private PrimitiveObjectInspector sampleRowsOI;

    // For PARTIAL1 and COMPLETE
    // private PrimitiveObjectInspector inputOI;

    // For PARTIAL2 and FINAL: the partial result is a struct of three longs.
    private StructObjectInspector soi;
    private StructField countField;
    private StructField totalRowsField;
    private StructField sampleRowsField;
    private LongObjectInspector countFieldOI;
    private LongObjectInspector totalRowsFieldOI;
    private LongObjectInspector sampleRowsFieldOI;

    // For PARTIAL1 and PARTIAL2: reused output struct {count, totalRows, sampleRows}.
    private Object[] partialResult;

    // For FINAL and COMPLETE: formatted "estimate +/- error" string.
    Text result;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
      super.init(m, parameters);

      // The last two arguments are always (…, sampleRows, totalRows);
      // length 2 is approx_count(*)-style, length 3 carries an expression.
      if (parameters.length == 2) {
        totalRowsOI = (PrimitiveObjectInspector) parameters[1];
        sampleRowsOI = (PrimitiveObjectInspector) parameters[0];
      } else if (parameters.length == 3) {
        totalRowsOI = (PrimitiveObjectInspector) parameters[2];
        sampleRowsOI = (PrimitiveObjectInspector) parameters[1];
      }

      // init input
      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
        // sameerag: approx_count doesn't need to read input
        // inputOI = (PrimitiveObjectInspector) parameters[0];
      } else {
        // PARTIAL2 / FINAL: input is the partial struct emitted below.
        soi = (StructObjectInspector) parameters[0];

        countField = soi.getStructFieldRef("count");
        totalRowsField = soi.getStructFieldRef("totalRows");
        sampleRowsField = soi.getStructFieldRef("sampleRows");

        countFieldOI = (LongObjectInspector) countField.getFieldObjectInspector();
        totalRowsFieldOI = (LongObjectInspector) totalRowsField.getFieldObjectInspector();
        sampleRowsFieldOI = (LongObjectInspector) sampleRowsField.getFieldObjectInspector();
      }

      // init output
      if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
        ArrayList<String> fname = new ArrayList<String>();
        fname.add("count");
        fname.add("totalRows");
        fname.add("sampleRows");
        partialResult = new Object[3];
        partialResult[0] = new LongWritable(0);
        partialResult[1] = new LongWritable(0);
        partialResult[2] = new LongWritable(0);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
      } else {
        // FINAL / COMPLETE: emit the human-readable estimate string.
        result = new Text();
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
      }
    }

    private ApproxUDAFCountEvaluator setCountAllColumns(boolean countAllCols) {
      countAllColumns = countAllCols;
      return this;
    }

    /** class for storing count value. */
    static class CountAgg implements AggregationBuffer {
      long value;       // rows counted in the sample
      long totalRows;   // full table size (for scaling)
      long sampleRows;  // sample size (for scaling / error bounds)
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      CountAgg buffer = new CountAgg();
      reset(buffer);
      return buffer;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((CountAgg) agg).value = 0;
      ((CountAgg) agg).totalRows = 0;
      ((CountAgg) agg).sampleRows = 0;
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      // parameters == null means the input table/split is empty
      if (parameters == null) {
        return;
      }

      // Record the scaling constants; they are the same on every row, so the
      // repeated assignment is idempotent.
      if (parameters.length == 2) {
        ((CountAgg) agg).totalRows =
            PrimitiveObjectInspectorUtils.getLong(parameters[1], totalRowsOI);
        ((CountAgg) agg).sampleRows =
            PrimitiveObjectInspectorUtils.getLong(parameters[0], sampleRowsOI);
      } else if (parameters.length == 3) {
        ((CountAgg) agg).totalRows =
            PrimitiveObjectInspectorUtils.getLong(parameters[2], totalRowsOI);
        ((CountAgg) agg).sampleRows =
            PrimitiveObjectInspectorUtils.getLong(parameters[1], sampleRowsOI);
      }

      if (countAllColumns) {
        // assert parameters.length == 0;
        ((CountAgg) agg).value++;
      } else {
        assert parameters.length > 0;
        // NOTE(review): this null-scan also covers the trailing
        // sampleRows/totalRows arguments, not just the counted expression —
        // presumably those are never NULL; confirm against the rewriter that
        // injects them.
        boolean countThisRow = true;
        for (Object nextParam : parameters) {
          if (nextParam == null) {
            countThisRow = false;
            break;
          }
        }
        if (countThisRow) {
          ((CountAgg) agg).value++;
        }
      }
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        Object partialCount = soi.getStructFieldData(partial, countField);
        Object partialTotalRows = soi.getStructFieldData(partial, totalRowsField);
        Object partialSampleRows = soi.getStructFieldData(partial, sampleRowsField);

        long p = countFieldOI.get(partialCount);
        long q = totalRowsFieldOI.get(partialTotalRows);
        long r = sampleRowsFieldOI.get(partialSampleRows);

        LOG.info("Merge Value: " + p);
        LOG.info("Merge Total Rows: " + q);
        LOG.info("Merge Sampling Rows: " + r);

        // Counts are additive across partials; the scaling constants are
        // global, so any partial's copy overwrites with the same value.
        ((CountAgg) agg).value += p;
        ((CountAgg) agg).totalRows = q;
        ((CountAgg) agg).sampleRows = r;
      }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      CountAgg myagg = (CountAgg) agg;

      // FIX: guard the sampleRows == 0 case (e.g. empty input): the original
      // divided by zero in double arithmetic and emitted "0 +/- NaN ...".
      if (myagg.sampleRows == 0) {
        result.set("0 +/- 0.0 (99% Confidence) ");
        return result;
      }

      // Scale the sampled count up to the full table.
      long approx_count =
          (long) ((double) (myagg.value * myagg.totalRows) / myagg.sampleRows);
      // p-hat: fraction of sampled rows that were counted.
      double probability = ((double) myagg.value) / ((double) myagg.sampleRows);

      LOG.info("Value: " + myagg.value);
      LOG.info("TotalRows: " + myagg.totalRows);
      LOG.info("Probability: " + probability);
      LOG.info("Sampling Ratio: " + (((double) myagg.sampleRows) / myagg.totalRows));

      // 2.575 is the z-score for a 99% confidence interval; the error is the
      // scaled standard deviation of the binomial count: sqrt(n*p*(1-p)).
      StringBuilder sb = new StringBuilder();
      sb.append(approx_count);
      sb.append(" +/- ");
      sb.append(Math.ceil(2.575 * (((double) myagg.totalRows) / myagg.sampleRows)
          * Math.sqrt(myagg.value * (1 - probability))));
      sb.append(" (99% Confidence) ");

      result.set(sb.toString());
      return result;
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      CountAgg myagg = (CountAgg) agg;
      ((LongWritable) partialResult[0]).set(myagg.value);
      ((LongWritable) partialResult[1]).set(myagg.totalRows);
      ((LongWritable) partialResult[2]).set(myagg.sampleRows);
      return partialResult;
    }
  }
}