/**
*
*/
package edu.washington.escience.myria.expression.evaluate;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.codehaus.janino.ExpressionEvaluator;
import com.google.common.base.Preconditions;
import com.gs.collections.api.iterator.IntIterator;
import com.gs.collections.impl.list.mutable.primitive.IntArrayList;
import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap;
import edu.washington.escience.myria.DbException;
import edu.washington.escience.myria.MyriaConstants;
import edu.washington.escience.myria.Schema;
import edu.washington.escience.myria.Type;
import edu.washington.escience.myria.api.encoding.FunctionStatus;
import edu.washington.escience.myria.column.builder.ColumnBuilder;
import edu.washington.escience.myria.column.builder.ColumnFactory;
import edu.washington.escience.myria.column.builder.WritableColumn;
import edu.washington.escience.myria.expression.Expression;
import edu.washington.escience.myria.expression.ExpressionOperator;
import edu.washington.escience.myria.expression.PyUDFExpression;
import edu.washington.escience.myria.expression.StateExpression;
import edu.washington.escience.myria.expression.VariableExpression;
import edu.washington.escience.myria.functions.PythonFunctionRegistrar;
import edu.washington.escience.myria.functions.PythonWorker;
import edu.washington.escience.myria.operator.Apply;
import edu.washington.escience.myria.operator.StatefulApply;
import edu.washington.escience.myria.storage.MutableTupleBuffer;
import edu.washington.escience.myria.storage.ReadableTable;
import edu.washington.escience.myria.storage.TupleBuffer;
/**
* An Expression evaluator for Python UDFs. Used in {@link Apply} and {@link StatefulApply}.
*/
public class PythonUDFEvaluator extends GenericEvaluator {
/** logger for this class. */
private static final org.slf4j.Logger LOGGER =
org.slf4j.LoggerFactory.getLogger(PythonUDFEvaluator.class);
/** python function registrar from which to fetch function pickle. */
private final PythonFunctionRegistrar pyFuncRegistrar;
/** python worker process. */
private PythonWorker pyWorker;
/** index of state column. */
private Set<Integer> stateColumns;
/** column indices of child ops. */
private int[] columnIdxs = null;
/** Output Type of the expression. */
private Type outputType = null;
/** is expression a flatmap? */
private Boolean isMultiValued = false;
/** Tuple buffers for each group key. */
private TupleBuffer buffer;
/** Mapping from row indices of state to rows in {@link PythonUDFEvaluator#buffer}. */
private IntObjectHashMap<IntArrayList> groups;
/** The internal state schema. */
private Schema stateSchema;
/**
* Default constructor.
*
* @param expression the expression for the evaluator
* @param parameters parameters that are passed to the expression
* @param pyFuncReg python function registrar to get the python function.
*/
public PythonUDFEvaluator(
final Expression expression, final ExpressionOperatorParameter parameters)
throws DbException {
super(expression, parameters);
pyFuncRegistrar = parameters.getPythonFunctionRegistrar();
if (pyFuncRegistrar == null) {
throw new RuntimeException("PythonRegistrar should not be null in PythonUDFEvaluator.");
}
PyUDFExpression op = (PyUDFExpression) expression.getRootExpressionOperator();
outputType = op.getOutputType(parameters);
List<ExpressionOperator> childops = op.getChildren();
columnIdxs = new int[childops.size()];
stateColumns = new HashSet<Integer>();
List<Type> types = new ArrayList<Type>();
for (int i = 0; i < childops.size(); i++) {
if (childops.get(i) instanceof StateExpression) {
stateColumns.add(i);
columnIdxs[i] = ((StateExpression) childops.get(i)).getColumnIdx();
types.add(((StateExpression) childops.get(i)).getOutputType(parameters));
} else if (childops.get(i) instanceof VariableExpression) {
columnIdxs[i] = ((VariableExpression) childops.get(i)).getColumnIdx();
types.add(((VariableExpression) childops.get(i)).getOutputType(parameters));
} else {
throw new IllegalStateException(
"Python expression can only have State or Variable expression as child expressions.");
}
}
stateSchema = new Schema(types);
String pyFunctionName = op.getName();
FunctionStatus fs = pyFuncRegistrar.getFunctionStatus(pyFunctionName);
if (fs == null) {
throw new DbException("No Python UDF with name " + pyFunctionName + " is registered.");
}
isMultiValued = fs.getIsMultiValued();
pyWorker = new PythonWorker();
pyWorker.sendCodePickle(fs.getBinary(), columnIdxs.length, outputType, isMultiValued);
buffer = new TupleBuffer(stateSchema);
groups = new IntObjectHashMap<IntArrayList>();
}
/**
* Creates an {@link ExpressionEvaluator} from the {@link #javaExpression}. This does not really compile the
* expression and is thus faster.
*/
@Override
public void compile() {
/* Do nothing! */
}
@Override
public void eval(
@Nonnull final ReadableTable input,
final int inputRow,
@Nullable final ReadableTable state,
final int stateRow,
@Nonnull final WritableColumn result,
@Nullable final WritableColumn count)
throws DbException {
pyWorker.sendNumTuples(1);
for (int i = 0; i < columnIdxs.length; ++i) {
if (stateColumns.contains(i)) {
writeToStream(state, stateRow, columnIdxs[i]);
} else {
writeToStream(input, inputRow, columnIdxs[i]);
}
}
readFromStream(count, result);
}
@Override
public void updateState(
@Nonnull final ReadableTable input,
final int inputRow,
@Nonnull final MutableTupleBuffer state,
final int stateRow,
final int stateColOffset)
throws DbException {
if (!groups.containsKey(stateRow)) {
groups.put(stateRow, new IntArrayList());
}
for (int i = 0; i < columnIdxs.length; ++i) {
if (stateColumns.contains(i)) {
buffer.put(i, state.asColumn(columnIdxs[i] + stateColOffset), stateRow);
} else {
buffer.put(i, input.asColumn(columnIdxs[i]), inputRow);
}
}
IntArrayList indices = groups.get(stateRow);
indices.add(buffer.numTuples() - 1);
};
/**
* @param state state
* @param col column index of the state to be written to.
* @throws DbException in case of error
*/
public void evalGroups(final MutableTupleBuffer state, final int col) throws DbException {
IntIterator keyIter = groups.keySet().intIterator();
while (keyIter.hasNext()) {
int key = keyIter.next();
pyWorker.sendNumTuples(groups.get(key).size());
IntIterator rowIter = groups.get(key).intIterator();
while (rowIter.hasNext()) {
int row = rowIter.next();
for (int i = 0; i < buffer.numColumns(); ++i) {
writeToStream(buffer, row, i);
}
}
ColumnBuilder<?> output = ColumnFactory.allocateColumn(outputType);
/* TODO: Leaving the count column to be null for now since since it's not used by Python evaluator for aggregate.
* A better design is to let the Aggregator emit two columns or even multiple columns. */
readFromStream(null, output);
if (output.size() > 1) {
throw new RuntimeException("PythonUDFEvaluator cannot be multivalued for Aggregate");
}
for (int i = 0; i < output.size(); ++i) {
state.replace(col, key, output, i);
}
}
}
/**
* @param count number of tuples returned.
* @param result writable column
* @param result2 appendable table
* @param resultColIdx id of the result column.
* @throws DbException in case of error.
*/
public void readFromStream(final WritableColumn count, final WritableColumn result)
throws DbException {
DataInputStream dIn = pyWorker.getDataInputStream();
int c = 1; // single valued expressions only return 1 tuple.
try {
// if it is a flat map operation, read number of tuples to be read.
if (isMultiValued) {
c = dIn.readInt();
count.appendInt(c);
} else { // not flatmap
if (count != null) { // count column is not null
count.appendInt(1);
}
}
for (int i = 0; i < c; i++) {
// then read the type of tuple
int type = dIn.readInt();
// if the 'type' is exception, throw exception
if (type == MyriaConstants.PythonSpecialLengths.PYTHON_EXCEPTION.getVal()) {
int excepLength = dIn.readInt();
byte[] excp = new byte[excepLength];
dIn.readFully(excp);
throw new DbException(new String(excp));
} else {
// read the rest of the tuple
if (type == MyriaConstants.PythonType.DOUBLE.getVal()) {
result.appendDouble(dIn.readDouble());
} else if (type == MyriaConstants.PythonType.FLOAT.getVal()) {
result.appendFloat(dIn.readFloat());
} else if (type == MyriaConstants.PythonType.INT.getVal()) {
result.appendInt(dIn.readInt());
} else if (type == MyriaConstants.PythonType.LONG.getVal()) {
result.appendLong(dIn.readLong());
} else if (type == MyriaConstants.PythonType.BLOB.getVal()) {
int l = dIn.readInt();
if (l > 0) {
byte[] obj = new byte[l];
dIn.readFully(obj);
result.appendBlob(ByteBuffer.wrap(obj));
}
} else {
throw new DbException("Type not supported by python");
}
}
}
} catch (Exception e) {
throw new DbException(e);
}
}
/**
* helper function to write to python process.
*
* @param tb - input tuple buffer.
* @param row - row being evaluated.
* @param columnIdx - column to be written to the py process.
* @throws DbException in case of error.
*/
private void writeToStream(@Nonnull final ReadableTable tb, final int row, final int columnIdx)
throws DbException {
DataOutputStream dOut = pyWorker.getDataOutputStream();
Preconditions.checkNotNull(tb, "input tuple cannot be null");
Preconditions.checkNotNull(dOut, "Output stream for python process cannot be null");
try {
switch (tb.getSchema().getColumnType(columnIdx)) {
case BOOLEAN_TYPE:
LOGGER.debug("BOOLEAN type not supported for python function");
break;
case DOUBLE_TYPE:
dOut.writeInt(MyriaConstants.PythonType.DOUBLE.getVal());
dOut.writeInt(Double.SIZE / Byte.SIZE);
dOut.writeDouble(tb.getDouble(columnIdx, row));
break;
case FLOAT_TYPE:
dOut.writeInt(MyriaConstants.PythonType.FLOAT.getVal());
dOut.writeInt(Float.SIZE / Byte.SIZE);
dOut.writeFloat(tb.getFloat(columnIdx, row));
break;
case INT_TYPE:
dOut.writeInt(MyriaConstants.PythonType.INT.getVal());
dOut.writeInt(Integer.SIZE / Byte.SIZE);
dOut.writeInt(tb.getInt(columnIdx, row));
break;
case LONG_TYPE:
dOut.writeInt(MyriaConstants.PythonType.LONG.getVal());
dOut.writeInt(Long.SIZE / Byte.SIZE);
dOut.writeLong(tb.getLong(columnIdx, row));
break;
case STRING_TYPE:
LOGGER.debug("STRING type is not yet supported for python function ");
break;
case DATETIME_TYPE:
LOGGER.debug("date time not yet supported for python function ");
break;
case BLOB_TYPE:
dOut.writeInt(MyriaConstants.PythonType.BLOB.getVal());
ByteBuffer input = tb.getBlob(columnIdx, row);
if (input != null && input.hasArray()) {
dOut.writeInt(input.array().length);
dOut.write(input.array());
} else {
dOut.writeInt(MyriaConstants.PythonSpecialLengths.NULL_LENGTH.getVal());
}
}
dOut.flush();
} catch (Exception e) {
throw new DbException(e);
}
}
}