package edu.washington.escience.myria.operator; import java.io.IOException; import java.util.ArrayList; import java.util.List; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.column.builder.ColumnBuilder; import edu.washington.escience.myria.column.builder.ColumnFactory; import edu.washington.escience.myria.expression.Expression; import edu.washington.escience.myria.expression.evaluate.ConstantEvaluator; import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter; import edu.washington.escience.myria.expression.evaluate.GenericEvaluator; import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator; import edu.washington.escience.myria.storage.Tuple; import edu.washington.escience.myria.storage.TupleBatch; /** * Apply operator that has to be initialized and carries a state while new tuples are generated. */ public class StatefulApply extends Apply { /***/ private static final long serialVersionUID = 1L; /** logger for this class. */ private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(StatefulApply.class); /** * Expressions that are used to initialize the state. */ private ImmutableList<Expression> initExpressions; /** * Expressions that are used to update the state. */ private ImmutableList<Expression> updateExpressions; /** * The states that are passed during execution. */ private Tuple state; /** * Evaluators that update the {@link #state}. One evaluator for each expression in {@link #updateExpressions}. */ private ArrayList<GenericEvaluator> updateEvaluators; /** * Schema of the state relation. */ private Schema stateSchema = null; /** * @param child child operator that data is fetched from * @param emitExpression expressions that creates the output * @param initializerExpressions expressions that initializes the state * @param updaterExpressions expressions that update the state */ public StatefulApply( final Operator child, final List<Expression> emitExpression, final List<Expression> initializerExpressions, final List<Expression> updaterExpressions) { super(child, emitExpression); Preconditions.checkArgument(initializerExpressions.size() == updaterExpressions.size()); for (int i = 0; i < initializerExpressions.size(); i++) { Preconditions.checkArgument( updaterExpressions.get(i).getOutputName() == null || initializerExpressions .get(i) .getOutputName() .equals(updaterExpressions.get(i).getOutputName())); } setInitExpressions(initializerExpressions); setUpdateExpressions(updaterExpressions); } /** * @param initializerExpressions the expressions that initialize the state */ private void setInitExpressions(final List<Expression> initializerExpressions) { initExpressions = ImmutableList.copyOf(initializerExpressions); } /** * @param updaterExpressions the expressions that update the state */ private void setUpdateExpressions(final List<Expression> updaterExpressions) { updateExpressions = ImmutableList.copyOf(updaterExpressions); } @Override protected TupleBatch fetchNextReady() throws DbException, IOException { Operator child = getChild(); if (child.eoi() || getChild().eos()) { return null; } TupleBatch tb = child.nextReady(); if (tb == null) { return null; } final int numColumns = getSchema().numColumns(); List<Column<?>> output = Lists.newArrayList(new Column<?>[numColumns]); List<Integer> needState = Lists.newLinkedList(); // first, generate columns that do not require state. This can often be optimized. for (int columnIdx = 0; columnIdx < numColumns; columnIdx++) { final GenericEvaluator evaluator = getEmitEvaluators().get(columnIdx); Preconditions.checkArgument( !evaluator.getExpression().isMultiValued(), "A multivalued expression cannot be used in StatefulApply."); if (!evaluator.needsState() || evaluator.isCopyFromInput()) { output.set(columnIdx, evaluator.evalTupleBatch(tb, getSchema()).getResultColumns().get(0)); } else { needState.add(columnIdx); } } // second, build and add the columns that require state List<ColumnBuilder<?>> columnBuilders = Lists.newArrayListWithCapacity(needState.size()); for (int builderIdx = 0; builderIdx < needState.size(); builderIdx++) { columnBuilders.add( ColumnFactory.allocateColumn( getEmitEvaluators().get(needState.get(builderIdx)).getOutputType())); } for (int rowIdx = 0; rowIdx < tb.numTuples(); rowIdx++) { // update state Tuple newState = new Tuple(getStateSchema()); for (int columnIdx = 0; columnIdx < stateSchema.numColumns(); columnIdx++) { updateEvaluators .get(columnIdx) .eval(tb, rowIdx, state, 0, newState.asWritableColumn(columnIdx), null); } state = newState; // apply expression for (int index = 0; index < needState.size(); index++) { final GenericEvaluator evaluator = getEmitEvaluators().get(needState.get(index)); // TODO: optimize the case where the state is copied directly evaluator.eval(tb, rowIdx, state, 0, columnBuilders.get(index), null); } } for (int builderIdx = 0; builderIdx < needState.size(); builderIdx++) { output.set(needState.get(builderIdx), columnBuilders.get(builderIdx).build()); } return new TupleBatch(getSchema(), output); } @Override protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException { Preconditions.checkArgument(initExpressions.size() == updateExpressions.size()); Preconditions.checkNotNull(getEmitExpressions()); final Schema inputSchema = getChild().getSchema(); ArrayList<GenericEvaluator> evaluators = new ArrayList<>(); evaluators.ensureCapacity(getEmitExpressions().size()); // initialize evaluators for Emit expressions. // these can only be generic or python expressions. for (Expression expr : getEmitExpressions()) { GenericEvaluator evaluator; if (expr.isConstant()) { evaluator = new ConstantEvaluator(expr, new ExpressionOperatorParameter(inputSchema, getNodeID())); } else if (expr.isRegisteredPythonUDF()) { evaluator = new PythonUDFEvaluator( expr, new ExpressionOperatorParameter( inputSchema, getStateSchema(), getNodeID(), getPythonFunctionRegistrar())); } else { evaluator = new GenericEvaluator( expr, new ExpressionOperatorParameter(inputSchema, getStateSchema(), getNodeID())); } evaluators.add(evaluator); } setEmitEvaluators(evaluators); updateEvaluators = new ArrayList<>(); updateEvaluators.ensureCapacity(updateExpressions.size()); state = new Tuple(getStateSchema()); // initialize init evaluators. these could be constant expressions only! for (int columnIdx = 0; columnIdx < getStateSchema().numColumns(); columnIdx++) { Expression expr = initExpressions.get(columnIdx); ConstantEvaluator evaluator = new ConstantEvaluator(expr, new ExpressionOperatorParameter(inputSchema, getNodeID())); evaluator.compile(); state.putObject(columnIdx, evaluator.eval()); } // initialize update evaluators -- these can be generic or python evaluators for (Expression expr : updateExpressions) { GenericEvaluator evaluator; if (expr.isRegisteredPythonUDF()) { evaluator = new PythonUDFEvaluator( expr, new ExpressionOperatorParameter( inputSchema, getStateSchema(), getNodeID(), getPythonFunctionRegistrar())); } else { evaluator = new GenericEvaluator( expr, new ExpressionOperatorParameter(inputSchema, getStateSchema(), getNodeID())); } evaluator.compile(); updateEvaluators.add(evaluator); } } /** * @return The schema of the state relation. */ private Schema getStateSchema() { if (stateSchema == null) { return generateStateSchema(); } return stateSchema; } /** * Generates the state schema and returns it. * * @return the state schema */ private Schema generateStateSchema() { ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder(); ImmutableList.Builder<String> namesBuilder = ImmutableList.builder(); for (Expression expr : initExpressions) { typesBuilder.add(expr.getOutputType(new ExpressionOperatorParameter(getChild().getSchema()))); namesBuilder.add(expr.getOutputName()); } stateSchema = new Schema(typesBuilder.build(), namesBuilder.build()); return stateSchema; } @Override public Schema generateSchema() { if (getEmitExpressions() == null) { return null; } Operator child = getChild(); if (child == null) { return null; } Schema inputSchema = child.getSchema(); if (inputSchema == null) { return null; } if (getStateSchema() == null) { return null; } ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder(); ImmutableList.Builder<String> namesBuilder = ImmutableList.builder(); for (Expression expr : getEmitExpressions()) { typesBuilder.add( expr.getOutputType(new ExpressionOperatorParameter(inputSchema, getStateSchema()))); namesBuilder.add(expr.getOutputName()); } return new Schema(typesBuilder.build(), namesBuilder.build()); } }