package edu.washington.escience.myria.operator.agg;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull;
import org.codehaus.commons.compiler.CompilerFactoryFactory;
import org.codehaus.commons.compiler.IScriptEvaluator;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import edu.washington.escience.myria.DbException;
import edu.washington.escience.myria.MyriaConstants;
import edu.washington.escience.myria.Schema;
import edu.washington.escience.myria.Type;
import edu.washington.escience.myria.expression.Expression;
import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter;
import edu.washington.escience.myria.expression.evaluate.GenericEvaluator;
import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator;
import edu.washington.escience.myria.functions.PythonFunctionRegistrar;
/**
 * {@link AggregatorFactory} for user-defined aggregates (UDAs). A UDA is described by three lists of expressions:
 * initializers that compute the initial value of each state column, updaters that fold every input tuple into the
 * state, and emitters that compute the final aggregate values from the state.
 */
public class UserDefinedAggregatorFactory implements AggregatorFactory {
  /** Required for Java serialization. */
  private static final long serialVersionUID = 1L;
  /** logger for this class. */
  private static final org.slf4j.Logger LOGGER =
      org.slf4j.LoggerFactory.getLogger(UserDefinedAggregatorFactory.class);
  /** Expressions that initialize the state variables. Initializer i produces state column i. */
  @JsonProperty private final List<Expression> initializers;
  /** Expressions that update the state variables as a function of the input and the current tuple. */
  @JsonProperty private final List<Expression> updaters;
  /** Expressions that emit the final aggregation result from the state. */
  @JsonProperty private final List<Expression> emitters;
  /**
   * Python function registrar handed to the update evaluators' {@link ExpressionOperatorParameter}. It is
   * {@code null} until {@link #setPyFuncReg} is called.
   *
   * <p>NOTE(review): this field is not {@code transient} although the class is Java-serializable; confirm that
   * {@link PythonFunctionRegistrar} is serializable or that the factory is never serialized once this is set.
   */
  PythonFunctionRegistrar pyFuncReg;

  /**
   * Construct a new user-defined aggregate. The initializers set the initial state of the aggregate; the updaters
   * update this state for every new tuple. The emitters produce the final value of the aggregate. Note that there must
   * be the same number of initializers and updaters, but there may be any number > 0 of emitters.
   *
   * <p>The given lists are copied, so later modifications by the caller do not affect this factory.
   *
   * @param initializers Expressions that initialize the state variables.
   * @param updaters Expressions that update the state variables as a function of the input and the current tuple.
   * @param emitters Expressions that emit the final aggregation result from the state.
   * @throws NullPointerException if any list (or any element of a list) is {@code null}.
   */
  @JsonCreator
  public UserDefinedAggregatorFactory(
      @JsonProperty(value = "initializers", required = true) final List<Expression> initializers,
      @JsonProperty(value = "updaters", required = true) final List<Expression> updaters,
      @JsonProperty(value = "emitters", required = true) final List<Expression> emitters) {
    this.initializers = ImmutableList.copyOf(Objects.requireNonNull(initializers, "initializers"));
    this.updaters = ImmutableList.copyOf(Objects.requireNonNull(updaters, "updaters"));
    this.emitters = ImmutableList.copyOf(Objects.requireNonNull(emitters, "emitters"));
  }

  @Override
  public List<Expression> generateEmitExpressions(final Schema inputSchema) throws DbException {
    return emitters;
  }

  /**
   * Build the single {@link UserDefinedAggregator} for this UDA. Fresh evaluator lists are created on every call,
   * so repeated calls neither duplicate evaluators nor share mutable state between the returned aggregators.
   *
   * @param inputSchema the schema of the tuples being aggregated.
   * @return a one-element list containing the UDA's aggregator.
   * @throws DbException if an evaluator cannot be created.
   */
  @Override
  public List<Aggregator> generateInternalAggs(final Schema inputSchema) throws DbException {
    Schema stateSchema = generateStateSchema(inputSchema);

    /* Initializers are evaluated without a state schema, so their scripts append (put) to the state. */
    List<GenericEvaluator> initEvaluators = new ArrayList<>(initializers.size());
    for (int i = 0; i < initializers.size(); ++i) {
      initEvaluators.add(
          getEvaluator(initializers.get(i), new ExpressionOperatorParameter(inputSchema), i));
    }

    /* Updaters see both the input and the state; updater i rewrites state column i. */
    ExpressionOperatorParameter updateParam =
        new ExpressionOperatorParameter(inputSchema, stateSchema, pyFuncReg);
    List<GenericEvaluator> updateEvaluators = new ArrayList<>(updaters.size());
    for (int i = 0; i < updaters.size(); ++i) {
      Expression updater = updaters.get(i);
      if (updater.isRegisteredPythonUDF()) {
        updateEvaluators.add(new PythonUDFEvaluator(updater, updateParam));
      } else {
        updateEvaluators.add(getEvaluator(updater, updateParam, i));
      }
    }

    /* Use the same computation as generateSchema so the advertised and actual output schemas agree. */
    Schema resultSchema = generateSchema(inputSchema);
    return ImmutableList.of(
        new UserDefinedAggregator(initEvaluators, updateEvaluators, resultSchema, stateSchema));
  }

  /**
   * Produce a {@link GenericEvaluator} for one state column. The generated Java script first computes the expression
   * into a local variable and only then stores it into the state, so it is safe for the expression to read the same
   * state it writes (e.g., when updating state in an aggregate).
   *
   * <p>If {@code param} has no state schema (initialization), the script appends the value with
   * {@code state.put<TYPE>(col + stateColOffset, val)}. Otherwise (update), it overwrites the value in place with
   * {@code state.replace<TYPE>(col + stateColOffset, stateRow, val)}.
   *
   * @param expr the expression computing the value of state column {@code col}.
   * @param param the inputs that the expression may use, including the {@link Schema} of the expression inputs and
   *     worker-local variables.
   * @param col the index of the state column written by the expression.
   * @return an evaluator wrapping the generated script.
   * @throws DbException if the evaluator cannot be created.
   */
  private GenericEvaluator getEvaluator(
      @Nonnull final Expression expr,
      @Nonnull final ExpressionOperatorParameter param,
      final int col)
      throws DbException {
    Type type = expr.getOutputType(param);
    boolean updatesExistingState = param.getStateSchema() != null;

    StringBuilder compute = new StringBuilder();
    // <TYPE> val<I> = <EXPRESSION>;
    compute
        .append(type.toJavaType().getName())
        .append(" val")
        .append(col)
        .append(" = ")
        .append(expr.getJavaExpression(param))
        .append(";\n");
    // state.put<TYPE>(<I>+stateColOffset, val<I>);  or
    // state.replace<TYPE>(<I>+stateColOffset, stateRow, val<I>);
    compute
        .append(Expression.STATE)
        .append(updatesExistingState ? ".replace" : ".put")
        .append(type == Type.BLOB_TYPE ? "Blob" : type.toJavaObjectType().getSimpleName())
        .append("(")
        .append(col)
        .append("+")
        .append(Expression.STATECOLOFFSET);
    if (updatesExistingState) {
      compute.append(", ").append(Expression.STATEROW);
    }
    compute.append(", val").append(col).append(");\n");

    String script = compute.toString();
    LOGGER.debug("Compiling UDA {}", script);
    return new GenericEvaluator(expr, script, param);
  }

  /**
   * Compute the output schema of this UDA: one column per emitter. Types are derived from the input and state
   * schemas, and names come from the emitters.
   *
   * @param inputSchema the schema of the tuples being aggregated.
   * @return the schema of the emitted aggregate values.
   */
  @Override
  public Schema generateSchema(final Schema inputSchema) {
    Schema stateSchema = generateStateSchema(inputSchema);
    ExpressionOperatorParameter emitParam = new ExpressionOperatorParameter(inputSchema, stateSchema);
    ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
    ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
    for (Expression expr : emitters) {
      typesBuilder.add(expr.getOutputType(emitParam));
      namesBuilder.add(expr.getOutputName());
    }
    return Schema.of(typesBuilder.build(), namesBuilder.build());
  }

  /**
   * Compute the schema of the aggregate's state: one column per initializer, typed by evaluating the initializer
   * against the input schema.
   *
   * @param inputSchema the schema of the tuples being aggregated.
   * @return the state schema.
   */
  @Override
  public Schema generateStateSchema(final Schema inputSchema) {
    ExpressionOperatorParameter initParam = new ExpressionOperatorParameter(inputSchema);
    ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
    ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
    for (Expression expr : initializers) {
      typesBuilder.add(expr.getOutputType(initParam));
      namesBuilder.add(expr.getOutputName());
    }
    return Schema.of(typesBuilder.build(), namesBuilder.build());
  }

  /**
   * Set the Python function registrar used by update evaluators created in later calls to
   * {@link #generateInternalAggs}.
   *
   * @param pyFuncReg the registrar; may be {@code null}.
   */
  public void setPyFuncReg(final PythonFunctionRegistrar pyFuncReg) {
    this.pyFuncReg = pyFuncReg;
  }
}