package edu.washington.escience.myria.operator.agg;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import javax.annotation.Nullable;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;

import edu.washington.escience.myria.DbException;
import edu.washington.escience.myria.Schema;
import edu.washington.escience.myria.column.Column;
import edu.washington.escience.myria.expression.Expression;
import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter;
import edu.washington.escience.myria.expression.evaluate.GenericEvaluator;
import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator;
import edu.washington.escience.myria.functions.PythonFunctionRegistrar;
import edu.washington.escience.myria.operator.Operator;
import edu.washington.escience.myria.operator.UnaryOperator;
import edu.washington.escience.myria.operator.UniqueTupleHashTable;
import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp;
import edu.washington.escience.myria.storage.TupleBatch;
import edu.washington.escience.myria.storage.TupleBatchBuffer;
import edu.washington.escience.myria.util.MyriaArrayUtils;

/**
 * The Aggregation operator that computes an aggregate (e.g., sum, avg, max, min). This variant supports aggregates
 * over multiple columns and grouping by multiple columns.
 */
public class Aggregate extends UnaryOperator {

  /** Java requires this. **/
  private static final long serialVersionUID = 1L;

  /** The hash table containing groups and states. */
  protected transient UniqueTupleHashTable groupStates;
  /** Factories to make the Aggregators. **/
  private final AggregatorFactory[] factories;
  /** Aggregators of the internal state. */
  protected List<Aggregator> internalAggs;
  /** Expressions that emit output. */
  protected List<GenericEvaluator> emitEvals;
  /** Group fields. An empty array means no grouping. **/
  protected final int[] gfields;
  /** Buffer for storing results. */
  protected TupleBatchBuffer resultBuffer;

  /**
   * Groups the input tuples according to the specified grouping fields, then produces the specified aggregates.
   *
   * @param child The Operator that is feeding us tuples.
   * @param gfields The columns over which we are grouping the result. An empty array means no grouping.
   * @param factories The factories that will produce the {@link Aggregator}s.
   */
  public Aggregate(
      @Nullable final Operator child, final int[] gfields, final AggregatorFactory... factories) {
    super(child);
    this.gfields = gfields;
    this.factories = Objects.requireNonNull(factories, "factories");
  }

  @Override
  protected void cleanup() throws DbException {
    groupStates.cleanup();
    resultBuffer.clear();
  }
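  /*
   * State layout: each row of groupStates holds the group-by key columns first, followed by the
   * state columns of every internal aggregator in order, i.e. [gfields | state of agg 0 | state of
   * agg 1 | ...]. Each aggregator owns getStateSize() consecutive columns, so fetchNextReady() and
   * generateResult() walk the aggregator list while advancing a running column offset that starts
   * at gfields.length.
   */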
  /**
   * Returns the next ready batch of result tuples. The first columns are the group-by fields, if any, followed by
   * the columns of aggregate results generated by {@link Aggregate#emitEvals}.
   *
   * @throws DbException if any error occurs.
   * @return result TB.
   */
  @Override
  protected TupleBatch fetchNextReady() throws DbException {
    final Operator child = getChild();
    TupleBatch tb = child.nextReady();
    while (tb != null) {
      for (int row = 0; row < tb.numTuples(); ++row) {
        int index = groupStates.getIndex(tb, gfields, row);
        if (index == -1) {
          /* A new group: add its key columns and initialize every aggregator's state. */
          groupStates.addTuple(tb, gfields, row, true);
          int offset = gfields.length;
          for (Aggregator agg : internalAggs) {
            agg.initState(groupStates.getData(), offset);
            offset += agg.getStateSize();
          }
          index = groupStates.numTuples() - 1;
        }
        /* Fold the current row into the state of every aggregator for this group. */
        int offset = gfields.length;
        for (Aggregator agg : internalAggs) {
          agg.addRow(tb, row, groupStates.getData(), index, offset);
          offset += agg.getStateSize();
        }
      }
      tb = child.nextReady();
    }

    if (child.eos()) {
      /* Special check for count(*) as the only aggregate on an empty relation: emit 0. */
      if (getNumOutputTuples() == 0 && groupStates.numTuples() == 0 && isCountAllOnlyAggregate()) {
        resultBuffer.putLong(0, 0);
      }
      generateResult();
      return resultBuffer.popAny();
    }
    return null;
  }

  /**
   * Check if count(*) is the only aggregate and there is no group by.
   *
   * @return true if so.
   */
  private boolean isCountAllOnlyAggregate() {
    return gfields.length == 0
        && internalAggs.size() == 1
        && internalAggs.get(0) instanceof PrimitiveAggregator
        && ((PrimitiveAggregator) internalAggs.get(0)).aggOp == AggregationOp.COUNT;
  }

  /**
   * Generates the result tuples from the accumulated group states and adds them to {@link #resultBuffer}.
   *
   * @throws DbException if there is an error.
   */
  protected void generateResult() throws DbException {
    if (groupStates.numTuples() == 0) {
      return;
    }

    /* Let Python-based aggregators finalize their state before the emit expressions are evaluated. */
    int stateOffset = gfields.length;
    for (Aggregator agg : internalAggs) {
      if (agg instanceof UserDefinedAggregator) {
        ((UserDefinedAggregator) agg).finalizePythonUpdaters(groupStates.getData(), stateOffset);
      }
      stateOffset += agg.getStateSize();
    }

    Schema inputSchema = getChild().getSchema();
    for (TupleBatch tb : groupStates.getData().getAll()) {
      List<Column<?>> columns = new ArrayList<Column<?>>();
      columns.addAll(tb.getDataColumns().subList(0, gfields.length));
      stateOffset = gfields.length;
      int emitOffset = 0;
      for (AggregatorFactory factory : factories) {
        int stateSize = factory.generateStateSchema(inputSchema).numColumns();
        int emitSize = factory.generateSchema(inputSchema).numColumns();
        /* Slice out this factory's state columns and evaluate its emit expressions over them. */
        TupleBatch state = tb.selectColumns(MyriaArrayUtils.range(stateOffset, stateSize));
        for (GenericEvaluator eval : emitEvals.subList(emitOffset, emitOffset + emitSize)) {
          columns.add(eval.evalTupleBatch(state, getSchema()).getResultColumns().get(0));
        }
        stateOffset += stateSize;
        emitOffset += emitSize;
      }
      addToResult(columns);
    }
    groupStates.cleanup();
  }

  /**
   * @param columns result columns.
   */
  protected void addToResult(List<Column<?>> columns) {
    resultBuffer.absorb(new TupleBatch(getSchema(), columns), true);
  }
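  /*
   * Illustration of the offset bookkeeping in generateResult() above (hypothetical sizes, and
   * assuming MyriaArrayUtils.range(start, count) yields count consecutive indices starting at
   * start): with one group-by column, a factory keeping 1 state column and emitting 1 column, and
   * a second factory keeping 2 state columns and emitting 1 column, a state row looks like
   * [g, s0, s1a, s1b]. The first factory evaluates its emit expression over column {1}
   * (stateOffset 1, stateSize 1), the second over columns {2, 3} (stateOffset 2, stateSize 2),
   * while emitOffset selects the matching slice of emitEvals.
   */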
  /**
   * The schema of the aggregate output: the grouping fields first, followed by the aggregate fields.
   *
   * @return the resulting schema
   */
  @Override
  protected Schema generateSchema() {
    if (getChild() == null) {
      return null;
    }
    Schema inputSchema = getChild().getSchema();
    if (inputSchema == null) {
      return null;
    }
    Schema aggSchema = Schema.EMPTY_SCHEMA;
    for (int i = 0; i < factories.length; ++i) {
      aggSchema = Schema.merge(aggSchema, factories[i].generateSchema(inputSchema));
    }
    return Schema.merge(inputSchema.getSubSchema(gfields), aggSchema);
  }

  @Override
  protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException {
    Schema inputSchema = getChild().getSchema();
    Preconditions.checkState(inputSchema != null, "unable to determine schema in init");

    internalAggs = new ArrayList<Aggregator>();
    emitEvals = new ArrayList<GenericEvaluator>();
    Schema groupingSchema = inputSchema.getSubSchema(gfields);
    Schema stateSchema = Schema.EMPTY_SCHEMA;
    PythonFunctionRegistrar pyFuncReg = getPythonFunctionRegistrar();

    for (AggregatorFactory factory : factories) {
      factory.setPyFuncReg(pyFuncReg);
      internalAggs.addAll(factory.generateInternalAggs(inputSchema));
      List<Expression> emits = factory.generateEmitExpressions(inputSchema);
      Schema newStateSchema = factory.generateStateSchema(inputSchema);
      stateSchema = Schema.merge(stateSchema, newStateSchema);
      for (Expression exp : emits) {
        GenericEvaluator evaluator;
        if (exp.isRegisteredPythonUDF()) {
          evaluator =
              new PythonUDFEvaluator(
                  exp, new ExpressionOperatorParameter(inputSchema, stateSchema, pyFuncReg));
        } else {
          evaluator =
              new GenericEvaluator(
                  exp, new ExpressionOperatorParameter(newStateSchema, newStateSchema, pyFuncReg));
        }
        emitEvals.add(evaluator);
      }
    }

    groupStates =
        new UniqueTupleHashTable(
            Schema.merge(groupingSchema, stateSchema), MyriaArrayUtils.range(0, gfields.length));
    resultBuffer = new TupleBatchBuffer(getSchema());
  }
}