/**
*
*/
package edu.washington.escience.myria.functions;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.lang.ProcessBuilder.Redirect;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import com.google.common.base.Preconditions;
import edu.washington.escience.myria.DbException;
import edu.washington.escience.myria.MyriaConstants;
import edu.washington.escience.myria.Type;
/**
 * Manages one external Python worker process and the local socket connection used to talk to it.
 *
 * <p>On construction a server socket is bound to an ephemeral port on 127.0.0.1, the Python
 * module named by {@link MyriaConstants#PYTHONWORKER} is launched with the interpreter named by
 * {@link MyriaConstants#PYTHONEXEC}, and the chosen port is written to the process's standard
 * input. The constructor then waits for a connection on that port and wraps it in data streams.
 *
 * <p>Callers must invoke {@link #close()} when done to release the sockets and stop the process.
 */
public class PythonWorker {
  /** Unused by serialization: this class does not implement {@code Serializable}. */
  private static final long serialVersionUID = 1L;
  /** logger */
  private static final org.slf4j.Logger LOGGER =
      org.slf4j.LoggerFactory.getLogger(PythonWorker.class);
  /** server socket the python worker connects to. */
  private ServerSocket serverSocket = null;
  /** accepted connection to the python worker. */
  private Socket clientSock = null;
  /** python worker process. */
  private Process worker = null;
  /** stream for writing data to the python worker. */
  private DataOutputStream dOut;
  /** stream for reading data from the python worker. */
  private DataInputStream dIn;

  /**
   * Binds the local server socket, launches the Python worker and waits for it to connect.
   *
   * <p>If any step fails, every resource acquired so far (sockets, process) is released before
   * the exception is propagated.
   *
   * @throws DbException if the socket cannot be created or the worker cannot be started; the
   *     underlying failure is attached as the cause.
   */
  public PythonWorker() throws DbException {
    try {
      createServerSocket();
      startPythonWorker();
    } catch (Exception e) {
      final DbException failure = new DbException("Failed to create Python Worker");
      // Keep the original error reachable; the message-only exception used to hide it.
      failure.initCause(e);
      try {
        // Do not leak a bound port or a half-started process when construction fails.
        releaseResources();
      } catch (IOException cleanupError) {
        failure.addSuppressed(cleanupError);
      }
      throw failure;
    }
  }

  /**
   * Sends the Python function and its call metadata to the worker.
   *
   * <p>The message is written as: the byte length of the UTF-8 encoded code (int), the encoded
   * bytes, {@code numColumns} (int), the Python type code for {@code outputType} (int), and a
   * flat-map flag (int, 1 or 0). All arguments are validated before the first byte is written so
   * that a rejected call does not leave a partial message on the stream.
   *
   * @param pyCodeString python function string; must be non-null and non-empty.
   * @param numColumns number of columns to be written to python process.
   * @param outputType output type of the python function.
   * @param isFlatMap does the python function return multiple tuples for a single input?
   * @throws DbException if the code is empty, the output stream is not set up, the output type is
   *     not supported, or writing to the worker fails.
   */
  public void sendCodePickle(
      final String pyCodeString,
      final int numColumns,
      final Type outputType,
      final Boolean isFlatMap)
      throws DbException {
    Preconditions.checkNotNull(pyCodeString);
    if (pyCodeString.isEmpty() || dOut == null) {
      throw new DbException("Can't write Python Code to worker!");
    }
    Preconditions.checkNotNull(outputType, "outputType");
    Preconditions.checkNotNull(isFlatMap, "isFlatMap");
    // Resolve the type code up front: an unsupported type must fail before anything is sent.
    final int typeCode = pythonTypeCode(outputType);
    final int flatMapFlag = isFlatMap ? 1 : 0;
    final byte[] codeBytes = pyCodeString.getBytes(StandardCharsets.UTF_8);
    try {
      dOut.writeInt(codeBytes.length);
      dOut.write(codeBytes);
      dOut.writeInt(numColumns);
      dOut.writeInt(typeCode);
      dOut.writeInt(flatMapFlag);
      dOut.flush();
    } catch (IOException e) {
      LOGGER.debug("failed to send python code pickle", e);
      throw new DbException(e);
    }
  }

  /**
   * Writes the tuple count to the worker's output stream. The stream is not flushed here.
   *
   * @param numTuples number of tuples to be sent to python function; must be positive.
   * @throws IllegalArgumentException if {@code numTuples} is not positive.
   * @throws DbException if writing to the worker fails.
   */
  public void sendNumTuples(final int numTuples) throws DbException {
    Preconditions.checkArgument(numTuples > 0, "number of tuples: %s", numTuples);
    try {
      dOut.writeInt(numTuples);
    } catch (IOException e) {
      throw new DbException(e);
    }
  }

  /**
   * @return data output stream for writing to the python worker.
   */
  public DataOutputStream getDataOutputStream() {
    Preconditions.checkNotNull(dOut);
    return dOut;
  }

  /**
   * @return data input stream for reading from the python worker.
   */
  public DataInputStream getDataInputStream() {
    Preconditions.checkNotNull(dIn);
    return dIn;
  }

  /**
   * Closes the client and server sockets and destroys the worker process.
   *
   * <p>Every step is attempted even if an earlier one fails.
   *
   * @throws IOException the first failure encountered while closing a socket; a later failure is
   *     attached to it as a suppressed exception.
   */
  public void close() throws IOException {
    releaseResources();
  }

  /**
   * Releases the sockets and the worker process; shared by {@link #close()} and the constructor's
   * failure path.
   *
   * @throws IOException the first failure encountered while closing a socket.
   */
  private void releaseResources() throws IOException {
    IOException failure = null;
    if (clientSock != null) {
      try {
        clientSock.close();
      } catch (IOException e) {
        failure = e;
      }
    }
    if (serverSocket != null) {
      try {
        serverSocket.close();
      } catch (IOException e) {
        if (failure == null) {
          failure = e;
        } else {
          failure.addSuppressed(e);
        }
      }
    }
    // stop worker process, regardless of whether the sockets closed cleanly
    if (worker != null) {
      worker.destroy();
    }
    if (failure != null) {
      throw failure;
    }
  }

  /**
   * Binds the server socket to a system-chosen port on 127.0.0.1 with a backlog of 1.
   *
   * @throws UnknownHostException if the loopback address cannot be resolved.
   * @throws IOException if the socket cannot be opened.
   */
  private void createServerSocket() throws UnknownHostException, IOException {
    serverSocket = new ServerSocket(0, 1, InetAddress.getByName("127.0.0.1"));
  }

  /**
   * Launches the Python worker process, tells it which port to use, and accepts its connection.
   *
   * @throws IOException in case of error.
   */
  private void startPythonWorker() throws IOException {
    final ProcessBuilder builder =
        new ProcessBuilder(MyriaConstants.PYTHONEXEC, "-m", MyriaConstants.PYTHONWORKER);
    // environment variables for the child process
    final Map<String, String> env = builder.environment();
    env.put("PYTHONUNBUFFERED", "YES");
    env.put("PYTHON_EGG_CACHE", "/tmp/.python-eggs");
    // the worker's stdout/stderr go to this JVM's stdout/stderr
    builder.redirectError(Redirect.INHERIT);
    builder.redirectOutput(Redirect.INHERIT);
    worker = builder.start();
    // Hand the listening port to the worker over its stdin. The writer is deliberately left
    // open: closing it would also close the worker's stdin.
    final OutputStream stdin = worker.getOutputStream();
    final OutputStreamWriter portWriter = new OutputStreamWriter(stdin, StandardCharsets.UTF_8);
    portWriter.write(serverSocket.getLocalPort() + "\n");
    portWriter.flush();
    // NOTE(review): no accept timeout is set, so this blocks until something connects.
    clientSock = serverSocket.accept();
    setupStreams();
  }

  /**
   * Maps a Myria output type to the integer code sent to the python worker. Performs no I/O.
   *
   * @param outputType output type for python function
   * @return the value of the matching {@code MyriaConstants.PythonType}.
   * @throws DbException if the type has no python counterpart.
   */
  private static int pythonTypeCode(final Type outputType) throws DbException {
    switch (outputType) {
      case DOUBLE_TYPE:
        return MyriaConstants.PythonType.DOUBLE.getVal();
      case FLOAT_TYPE:
        return MyriaConstants.PythonType.FLOAT.getVal();
      case INT_TYPE:
        return MyriaConstants.PythonType.INT.getVal();
      case LONG_TYPE:
        return MyriaConstants.PythonType.LONG.getVal();
      case BLOB_TYPE:
        return MyriaConstants.PythonType.BLOB.getVal();
      default:
        throw new DbException("Type not supported for python UDF ");
    }
  }

  /**
   * Wraps the accepted socket's streams in data streams; does nothing if no client is connected.
   *
   * @throws IOException in case of error.
   */
  private void setupStreams() throws IOException {
    if (clientSock != null) {
      dOut = new DataOutputStream(clientSock.getOutputStream());
      dIn = new DataInputStream(clientSock.getInputStream());
    }
  }
}