package edu.washington.escience.myria.operator; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; import javax.annotation.Nonnull; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.MyriaConstants; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.expression.Expression; import edu.washington.escience.myria.expression.evaluate.ConstantEvaluator; import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter; import edu.washington.escience.myria.expression.evaluate.GenericEvaluator; import edu.washington.escience.myria.expression.evaluate.GenericEvaluator.EvaluatorResult; import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator; import edu.washington.escience.myria.storage.ReadableColumn; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBatchBuffer; import edu.washington.escience.myria.storage.TupleBuffer; /** * Generic apply operator for single- or multivalued expressions. */ public class Apply extends UnaryOperator { /***/ private static final long serialVersionUID = 1L; /** * List (possibly empty) of expressions that will be used to create the output. */ @Nonnull private ImmutableList<Expression> emitExpressions = ImmutableList.of(); /** * One evaluator for each expression in {@link #emitExpressions}. */ @Nonnull private ImmutableList<GenericEvaluator> emitEvaluators = ImmutableList.of(); /** * Buffer to hold finished and in-progress TupleBatches. */ private TupleBatchBuffer outputBuffer; /** * AddCounter to the returning tuplebatch. */ private Boolean addCounter = false; /** * @return the {@link #emitExpressions} */ protected ImmutableList<Expression> getEmitExpressions() { return emitExpressions; } /** * @return the {@link #emitEvaluators} */ public List<GenericEvaluator> getEmitEvaluators() { return emitEvaluators; } /** * @param evaluators the evaluators to set */ public void setEmitEvaluators(final List<GenericEvaluator> evaluators) { emitEvaluators = ImmutableList.copyOf(evaluators); } /** * @return if there are no multivalued emit expressions */ private boolean onlySingleValuedExpressions() { for (Expression expr : getEmitExpressions()) { if (expr.isMultiValued()) { return false; } } return true; } /** * @return number of columns that return more than one value for this Apply operator. */ private int numberOfMultiValuedExpressions() { int i = 0; for (Expression expr : getEmitExpressions()) { if (expr.isMultiValued()) { i += 1; } } return i; } /** * Should a counter be added? * * @return */ private boolean getAddCounter() { return (this.addCounter && (numberOfMultiValuedExpressions() == 1)); } private void setAddCounter(Boolean addCounter) { this.addCounter = addCounter; } /** * The logger for debug, trace, etc. messages in this class. */ private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(Apply.class); /** * @param child child operator that data is fetched from * @param emitExpressions expression that created the output */ public Apply(final Operator child, @Nonnull final List<Expression> emitExpressions) { super(child); Preconditions.checkNotNull(emitExpressions); setEmitExpressions(emitExpressions); } public Apply(final Operator child, List<Expression> emitExpressions, Boolean addCounter) { this(child, emitExpressions); if (addCounter != null) { setAddCounter(addCounter); } } /** * @param emitExpressions the emit expressions for each column */ private void setEmitExpressions(@Nonnull final List<Expression> emitExpressions) { this.emitExpressions = ImmutableList.copyOf(emitExpressions); } @Override protected TupleBatch fetchNextReady() throws DbException, InvocationTargetException, IOException { // If there's a batch already finished, return it, otherwise keep reading // batches from the child until we have a full batch or the child returns null. while (!outputBuffer.hasFilledTB()) { TupleBatch inputTuples = getChild().nextReady(); if (inputTuples != null) { if (onlySingleValuedExpressions()) { List<List<Column<?>>> tbs = new ArrayList<List<Column<?>>>(); for (final GenericEvaluator eval : emitEvaluators) { EvaluatorResult evalResult = eval.evalTupleBatch(inputTuples, getSchema()); List<Column<?>> cols = evalResult.getResultColumns(); for (int i = 0; i < cols.size(); ++i) { if (tbs.size() <= i) { tbs.add(new ArrayList<Column<?>>()); } tbs.get(i).add(cols.get(i)); } } for (List<Column<?>> tb : tbs) { outputBuffer.absorb(new TupleBatch(getSchema(), tb), true); } } else { // Evaluate expressions on each column and store counts and results. List<ReadableColumn> resultCountColumns = new ArrayList<>(); List<ReadableColumn> resultColumns = new ArrayList<>(); for (final GenericEvaluator eval : emitEvaluators) { EvaluatorResult evalResult = eval.evalTupleBatch(inputTuples, getSchema()); resultCountColumns.add(evalResult.getResultCounts()); resultColumns.add(evalResult.getResults()); } // Generate the Cartesian product and append to output buffer. int[] resultCounts = new int[emitEvaluators.size()]; int[] cumResultCounts = new int[emitEvaluators.size()]; int[] lastCumResultCounts = new int[emitEvaluators.size()]; int[] iteratorIndexes = new int[emitEvaluators.size()]; List<Type> types = Lists.newLinkedList(); types.add(Type.INT_TYPE); List<String> names = ImmutableList.of(MyriaConstants.FLATMAP_COLUMN_NAME); Schema countIdxSchema = new Schema(types, names); for (int rowIdx = 0; rowIdx < inputTuples.numTuples(); ++rowIdx) { // First, get all result counts for this row. boolean emptyProduct = false; for (int i = 0; i < resultCountColumns.size(); ++i) { int resultCount = resultCountColumns.get(i).getInt(rowIdx); resultCounts[i] = resultCount; lastCumResultCounts[i] = cumResultCounts[i]; cumResultCounts[i] += resultCounts[i]; if (resultCount == 0) { // If at least one evaluator returned zero results, the Cartesian product is empty. emptyProduct = true; } } if (!emptyProduct) { // Initialize each iterator to its starting index. Arrays.fill(iteratorIndexes, 0); // Iterate over each element of the Cartesian product and append to output TupleBuffer countIdx = null; int iteratorIdx = 0; int resultRowIdx = 0; int flatmapid = -1; do { for (iteratorIdx = 0; iteratorIdx < iteratorIndexes.length; ++iteratorIdx) { resultRowIdx = lastCumResultCounts[iteratorIdx] + iteratorIndexes[iteratorIdx]; if (getAddCounter() && flatmapid < iteratorIndexes[iteratorIdx]) { flatmapid = iteratorIndexes[iteratorIdx]; } outputBuffer.appendFromColumn( iteratorIdx, resultColumns.get(iteratorIdx), resultRowIdx); } if (getAddCounter()) { countIdx = new TupleBuffer(countIdxSchema, 1); countIdx.putInt(0, flatmapid); flatmapid = 0; outputBuffer.appendFromColumn(iteratorIdx, countIdx.asColumn(0), 0); } } while (!computeNextCombination(resultCounts, iteratorIndexes)); } } } } else { // We don't want to keep polling in a loop since this method is non-blocking. break; } } // If we produced a full batch, return it, otherwise finish the current batch and return it. return outputBuffer.popAny(); } /** * This method mutates {@link iteratorIndexes} on each call to yield the next element of the Cartesian product of * {@link upperBounds} in lexicographic order. If all elements have been exhausted, it returns true, otherwise it * returns false. * * @param upperBounds an immutable array of elements representing the sets we are forming the Cartesian product of, * where each set is of the form [0, i), where i is an element of {@link upperBounds} * @param iteratorIndexes a mutable array of elements representing the current element of the Cartesian product * @return if we have exhausted all elements of the Cartesian product */ private boolean computeNextCombination(final int[] upperBounds, final int[] iteratorIndexes) { boolean endOfIteration = false; int lastIteratorPos = iteratorIndexes.length - 1; // Count backward from the innermost iterator to the outermost. for (int iteratorPos = lastIteratorPos; iteratorPos >= 0; --iteratorPos) { // If the current iterator is not exhausted, increment it and exit the loop, // otherwise reset the current iterator and continue. if (iteratorIndexes[iteratorPos] < upperBounds[iteratorPos] - 1) { iteratorIndexes[iteratorPos] += 1; break; } else { // If the outermost iterator is exhausted, we are done. if (iteratorPos == 0) { endOfIteration = true; break; } else { // Reset the current iterator and continue. iteratorIndexes[iteratorPos] = 0; } } } return endOfIteration; } @Override protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException { Preconditions.checkNotNull(emitExpressions); Schema inputSchema = Objects.requireNonNull(getChild().getSchema()); List<GenericEvaluator> evals = new ArrayList<>(); final ExpressionOperatorParameter parameters = new ExpressionOperatorParameter( inputSchema, null, getNodeID(), getPythonFunctionRegistrar()); for (Expression expr : emitExpressions) { GenericEvaluator evaluator; if (expr.isConstant()) { evaluator = new ConstantEvaluator(expr, parameters); } else if (expr.isRegisteredPythonUDF()) { evaluator = new PythonUDFEvaluator(expr, parameters); } else { evaluator = new GenericEvaluator(expr, parameters); } Preconditions.checkArgument(!evaluator.needsState()); evals.add(evaluator); } setEmitEvaluators(evals); outputBuffer = new TupleBatchBuffer(generateSchema()); } @Override public Schema generateSchema() { Operator child = getChild(); if (child == null) { return null; } Schema inputSchema = child.getSchema(); if (inputSchema == null) { return null; } ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder(); ImmutableList.Builder<String> namesBuilder = ImmutableList.builder(); for (Expression expr : emitExpressions) { typesBuilder.add(expr.getOutputType(new ExpressionOperatorParameter(inputSchema))); namesBuilder.add(expr.getOutputName()); } if (getAddCounter()) { typesBuilder.add(Type.INT_TYPE); namesBuilder.add(MyriaConstants.FLATMAP_COLUMN_NAME); } return new Schema(typesBuilder.build(), namesBuilder.build()); } }