package edu.washington.escience.myria.operator.agg; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.expression.DivideExpression; import edu.washington.escience.myria.expression.Expression; import edu.washington.escience.myria.expression.ExpressionOperator; import edu.washington.escience.myria.expression.MinusExpression; import edu.washington.escience.myria.expression.SqrtExpression; import edu.washington.escience.myria.expression.TimesExpression; import edu.washington.escience.myria.expression.VariableExpression; import edu.washington.escience.myria.functions.PythonFunctionRegistrar; import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp; /** * A factory that generates aggregators for a primitive column. */ public class PrimitiveAggregatorFactory implements AggregatorFactory { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; /** Which column of the input to aggregate over. */ @JsonProperty private final int column; /** Which aggregate options are requested. See {@link PrimitiveAggregator}. */ @JsonProperty private final AggregationOp[] aggOps; /** * A wrapper for the {@link PrimitiveAggregator} implementations like {@link IntegerAggregator}. * * @param column which column of the input to aggregate over. * @param aggOps which aggregate operations are requested. See {@link PrimitiveAggregator}. */ @JsonCreator public PrimitiveAggregatorFactory( @JsonProperty(value = "column", required = true) final Integer column, @JsonProperty(value = "aggOps", required = true) final AggregationOp[] aggOps) { this.column = Objects.requireNonNull(column, "column").intValue(); this.aggOps = Objects.requireNonNull(aggOps, "aggOps"); Preconditions.checkNotNull(aggOps, "aggregation operator %s cannot be null"); } /** * @param column which column of the input to aggregate over. * @param aggOp which aggregate is requested. */ public PrimitiveAggregatorFactory(final Integer column, final AggregationOp aggOp) { this(column, new AggregationOp[] {aggOp}); } @Override public List<Aggregator> generateInternalAggs(final Schema inputSchema) { List<Aggregator> ret = new ArrayList<Aggregator>(); List<AggregationOp> ops = getInternalOps(); for (int i = 0; i < ops.size(); ++i) { ret.add(generateAgg(inputSchema, ops.get(i))); } return ret; } /** * @param inputSchema the input schema * @param aggOp the aggregation op * @param indices the column indices of this aggregator in the state hash table * @return the generated aggregator */ private Aggregator generateAgg(final Schema inputSchema, final AggregationOp aggOp) { String inputName = inputSchema.getColumnName(column); Type type = inputSchema.getColumnType(column); switch (type) { case BOOLEAN_TYPE: return new BooleanAggregator(inputName, column, aggOp); case DATETIME_TYPE: return new DateTimeAggregator(inputName, column, aggOp); case DOUBLE_TYPE: return new DoubleAggregator(inputName, column, aggOp); case FLOAT_TYPE: return new FloatAggregator(inputName, column, aggOp); case INT_TYPE: return new IntegerAggregator(inputName, column, aggOp); case LONG_TYPE: return new LongAggregator(inputName, column, aggOp); case STRING_TYPE: return new StringAggregator(inputName, column, aggOp); default: throw new IllegalArgumentException("Unknown column type: " + type); } } @Override public List<Expression> generateEmitExpressions(final Schema inputSchema) { List<AggregationOp> cols = getInternalOps(); List<Expression> exps = new ArrayList<Expression>(); for (int i = 0; i < aggOps.length; ++i) { String name = aggOps[i].toString().toLowerCase() + "_" + inputSchema.getColumnName(column); switch (aggOps[i]) { case COUNT: case MIN: case MAX: case SUM: exps.add(new Expression(name, new VariableExpression(cols.indexOf(aggOps[i])))); continue; case AVG: exps.add( new Expression( name, new DivideExpression( new VariableExpression(cols.indexOf(AggregationOp.SUM)), new VariableExpression(cols.indexOf(AggregationOp.COUNT))))); continue; case STDEV: ExpressionOperator sumExp = new VariableExpression(cols.indexOf(AggregationOp.SUM)); ExpressionOperator countExp = new VariableExpression(cols.indexOf(AggregationOp.COUNT)); ExpressionOperator sumSquaredExp = new VariableExpression(cols.indexOf(AggregationOp.SUM_SQUARED)); ExpressionOperator first = new DivideExpression(sumSquaredExp, countExp); ExpressionOperator second = new DivideExpression(sumExp, countExp); exps.add( new Expression( name, new SqrtExpression( new MinusExpression(first, new TimesExpression(second, second))))); continue; default: throw new IllegalArgumentException("Type " + aggOps[i] + " is invalid"); } } return exps; } /** * Generate the internal aggregation ops. Each used op corresponds to one column. * * @return the list of aggregation ops. */ public List<AggregationOp> getInternalOps() { Set<AggregationOp> colTypes = new HashSet<AggregationOp>(); for (AggregationOp aggOp : aggOps) { colTypes.addAll(getInternalOps(aggOp)); } List<AggregationOp> ret = new ArrayList<AggregationOp>(colTypes); Collections.sort(ret); return ret; } /** * @param op the emit aggregation op * @return the internal aggregation ops needed for computing the emit op */ private List<AggregationOp> getInternalOps(AggregationOp op) { switch (op) { case COUNT: case MIN: case MAX: case SUM: return ImmutableList.of(op); case AVG: return ImmutableList.of(AggregationOp.SUM, AggregationOp.COUNT); case STDEV: return ImmutableList.of(AggregationOp.SUM, AggregationOp.SUM_SQUARED, AggregationOp.COUNT); default: throw new IllegalArgumentException("Type " + op + " is invalid"); } } /** * @param input the input type * @param op the aggregation op * @return the output type of applying op on the input type */ public Type getAggColumnType(Type input, AggregationOp op) { switch (op) { case MIN: case MAX: return input; case COUNT: return Type.LONG_TYPE; case SUM: case SUM_SQUARED: if (input == Type.INT_TYPE || input == Type.LONG_TYPE) { return Type.LONG_TYPE; } if (input == Type.FLOAT_TYPE || input == Type.DOUBLE_TYPE) { return Type.DOUBLE_TYPE; } throw new IllegalArgumentException(op + " on " + input + " is invalid"); case AVG: case STDEV: return Type.DOUBLE_TYPE; default: throw new IllegalArgumentException(op + " on " + input + " is invalid"); } } @Override public Schema generateSchema(final Schema inputSchema) { List<String> names = new ArrayList<String>(); List<Type> types = new ArrayList<Type>(); for (AggregationOp op : aggOps) { types.add(getAggColumnType(inputSchema.getColumnType(column), op)); names.add(op.toString().toLowerCase() + "_" + inputSchema.getColumnName(column)); } return Schema.of(types, names); } @Override public Schema generateStateSchema(final Schema inputSchema) { List<String> names = new ArrayList<String>(); List<Type> types = new ArrayList<Type>(); for (AggregationOp op : getInternalOps()) { types.add(getAggColumnType(inputSchema.getColumnType(column), op)); names.add(op.toString().toLowerCase() + "_" + inputSchema.getColumnName(column)); } return Schema.of(types, names); } PythonFunctionRegistrar pyFuncReg; public void setPyFuncReg(PythonFunctionRegistrar pyFuncReg) { this.pyFuncReg = pyFuncReg; } }