package com.thinkbiganalytics.spark.repl;
/*-
* #%L
* thinkbig-commons-spark-repl
* %%
* Copyright (C) 2017 ThinkBig Analytics
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.google.common.base.Charsets;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.thinkbiganalytics.spark.util.ArrayUtils;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.hive.HiveContext;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.ScriptException;
import scala.tools.nsc.interpreter.NamedParam;
/**
* Interface for an interpreter that compiles and evaluates Scala code containing a Spark job.
*
* <p>Scripts may access a {@link SparkContext} through the {@code sc} variable and a {@link SQLContext} through the
* {@code sqlContext} variable.</p>
*
* <p>This class is <i>thread-safe</i> and ensures that only one script </p>
*/
public abstract class ScriptEngine {
/**
* End of line character
*/
private static final byte[] END_LINE = new byte[]{'\n'};
/**
* Label used by the compiler to indicate a compile error
*/
private static final byte[] LABEL = "<console>".getBytes(Charsets.UTF_8);
/**
* Separator between label, line number, and error message
*/
private static final byte[] SEPARATOR = new byte[]{':'};
/**
* Exception thrown by the last script
*/
@Nonnull
private final AtomicReference<Throwable> exception = new AtomicReference<>();
/**
* Compiler output stream for capturing compile errors
*/
@Nonnull
private final ByteArrayOutputStream out = new ByteArrayOutputStream();
/**
* Result of the last script
*/
@Nonnull
private final AtomicReference<Object> result = new AtomicReference<>();
/**
* Map of variable names to values for bindings
*/
@Nonnull
private final Map<String, Object> values = Maps.newHashMap();
/**
* Spark context
*/
@Nullable
private SparkContext sparkContext;
/**
* Spark SQL context
*/
@Nullable
private SQLContext sqlContext;
/**
* Executes the specified script.
*
* @param script the script to be executed
* @return the value returned from the script
* @throws ScriptException if an error occurs in the script
*/
@Nullable
public synchronized Object eval(@Nonnull final String script) throws ScriptException {
List<NamedParam> bindings = ImmutableList.of();
return eval(script, bindings);
}
/**
* Executes the specified script with the given bindings.
*
* @param script the script to be executed
* @param bindings the variable bindings to be accessible to the script
* @return the value returned from the script
* @throws ScriptException if an error occurs in the script
*/
@Nullable
public synchronized Object eval(@Nonnull final String script, @Nonnull final List<NamedParam> bindings)
throws ScriptException {
// Define class containing script
final StringBuilder cls = new StringBuilder();
cls.append("class Script (engine: com.thinkbiganalytics.spark.repl.ScriptEngine)");
cls.append(" extends com.thinkbiganalytics.spark.repl.Script (engine) {\n");
cls.append(" override def eval (): Any = {\n");
cls.append(script);
cls.append(" }\n");
// Add bindings to class
this.values.clear();
for (NamedParam param : bindings) {
cls.append(" def ");
cls.append(param.name());
cls.append(" (): ");
cls.append(param.tpe());
cls.append(" = getValue(\"");
cls.append(param.name());
cls.append("\")\n");
this.values.put(param.name(), param.value());
}
cls.append("}\n");
// Instantiate class
cls.append("new Script(engine).run()\n");
// Execute script
this.out.reset();
execute(cls.toString());
// Check for exception and return result
checkCompileError();
checkRuntimeError();
return this.result.get();
}
/**
* Gets the class loader used by the interpreter.
*
* @return the class loader
*/
@Nonnull
public abstract ClassLoader getClassLoader();
/**
* Gets the {@code SparkContext} available to scripts as {@code sc}.
*
* @return the Spark context
*/
@Nonnull
public SparkContext getSparkContext() {
if (this.sparkContext == null) {
this.sparkContext = createSparkContext();
}
return this.sparkContext;
}
/**
* Gets the {@code SQLContext} available to scripts as {@code sqlContext}.
*
* @return the SQL context
*/
@Nonnull
public SQLContext getSQLContext() {
if (this.sqlContext == null) {
this.sqlContext = new HiveContext(getSparkContext());
}
return this.sqlContext;
}
/**
* Creates the {@code SparkContext} that will be available to scripts as {@code sc}.
*
* @return the Spark context
*/
@Nonnull
protected abstract SparkContext createSparkContext();
/**
* Executes the specified script.
*
* @param script the script to be executed
* @throws ScriptException if an error occurs in the script
*/
protected abstract void execute(@Nonnull final String script) throws ScriptException;
/**
* Gets the writer for capturing compile errors.
*
* @return the compiler output stream
*/
protected PrintWriter getPrintWriter() {
return new PrintWriter(this.out);
}
/**
* Resets the engine state so the {@link SparkContext} can be recreated.
*/
protected void reset() {
// Stop Spark
if (sparkContext != null && !sparkContext.isStopped()) {
sparkContext.stop();
}
// Clear instance variables
exception.set(null);
out.reset();
result.set(null);
sparkContext = null;
sqlContext = null;
}
/**
* Gets the value of the specified binding.
*
* @param name the name of the binding
* @return the value of the binding
*/
@Nullable
Object getValue(@Nonnull final String name) {
return this.values.get(name);
}
/**
* Sets the runtime exception for the current script.
*
* @param t the exception
*/
void setException(@Nonnull final Throwable t) {
this.exception.set(t);
}
/**
* Sets the result of the current script.
*
* @param result the result
*/
void setResult(@Nullable final Object result) {
this.exception.set(null);
this.result.set(result);
}
/**
* Checks the output stream for a compile error.
*
* @throws ScriptException if a compile error is found
*/
private void checkCompileError() throws ScriptException {
byte[] outBytes = this.out.toByteArray();
// Look for label
int labelIndex = ArrayUtils.indexOf(outBytes, 0, outBytes.length, LABEL, 0, LABEL.length, 0);
if (labelIndex == -1) {
return;
}
// Look for start and end of message
int lineIndex = labelIndex + LABEL.length + 1;
int msgStart = ArrayUtils.indexOf(outBytes, 0, outBytes.length, SEPARATOR, 0, SEPARATOR.length, lineIndex) + 2;
int msgEnd = ArrayUtils.indexOf(outBytes, 0, outBytes.length, END_LINE, 0, END_LINE.length, msgStart);
// Throw exception
int line;
String message;
try {
line = Integer.parseInt(new String(outBytes, lineIndex, msgStart - lineIndex - 2));
message = new String(outBytes, msgStart, msgEnd - msgStart, "UTF-8");
} catch (UnsupportedEncodingException e) {
throw new IllegalStateException(e);
}
throw new ScriptException(message, "<console>", line);
}
/**
* Checks for a runtime exception.
*
* @throws ScriptException if an exception is found
*/
private void checkRuntimeError() throws ScriptException {
Throwable exception = this.exception.get();
if (exception != null) {
Throwables.propagateIfPossible(exception, ScriptException.class);
if (exception instanceof Exception) {
throw new ScriptException((Exception) exception);
} else {
throw new ScriptException(exception.getMessage());
}
}
}
}