/*
* Copyright 2014 TWO SIGMA OPEN SOURCE, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.twosigma.beaker.clojure;
import clojure.lang.RT;
import clojure.lang.Var;
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import com.twosigma.beaker.autocomplete.AutocompleteResult;
import com.twosigma.beaker.evaluator.Evaluator;
import com.twosigma.beaker.evaluator.InternalVariable;
import com.twosigma.beaker.jvm.classloader.DynamicClassLoaderSimple;
import com.twosigma.beaker.jvm.object.SimpleEvaluationObject;
import com.twosigma.beaker.jvm.threads.BeakerCellExecutor;
import com.twosigma.jupyter.KernelParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import static com.twosigma.beaker.jupyter.comm.KernelControlSetShellHandler.CLASSPATH;
import static com.twosigma.beaker.jupyter.comm.KernelControlSetShellHandler.IMPORTS;
public class ClojureEvaluator implements Evaluator {
private final static Logger logger = LoggerFactory.getLogger(ClojureEvaluator.class.getName());
private final String shellId;
private final String sessionId;
private List<String> classPath;
private List<String> imports;
private List<String> requirements;
private boolean exit;
private BeakerCellExecutor executor;
private workerThread myWorker;
private String outDir;
private String currenClojureNS;
private DynamicClassLoaderSimple loader;
private class jobDescriptor {
String codeToBeExecuted;
SimpleEvaluationObject outputObject;
jobDescriptor(String c, SimpleEvaluationObject o) {
codeToBeExecuted = c;
outputObject = o;
}
}
private static final String beaker_clojure_ns = "beaker_clojure_shell";
private Var clojureLoadString = null;
private final Semaphore syncObject = new Semaphore(0, true);
private final ConcurrentLinkedQueue<jobDescriptor> jobQueue = new ConcurrentLinkedQueue<jobDescriptor>();
private String initScriptSource()
throws IOException {
URL url = this.getClass().getClassLoader().getResource("init_clojure_script.txt");
return Resources.toString(url, Charsets.UTF_8);
}
public ClojureEvaluator(String id, String sId) {
shellId = id;
sessionId = sId;
classPath = new ArrayList<String>();
imports = new ArrayList<String>();
requirements = new ArrayList<>();
outDir = Evaluator.createJupyterTempFolder().toString();
init();
}
private void init() {
loader = new DynamicClassLoaderSimple(ClassLoader.getSystemClassLoader());
loader.addJars(classPath);
loader.addDynamicDir(outDir);
String loadFunctionPrefix = "run_str";
currenClojureNS = String.format("%1$s_%2$s", beaker_clojure_ns, shellId);
try {
String clojureInitScript = String.format(initScriptSource(), beaker_clojure_ns, shellId,
loadFunctionPrefix);
clojureLoadString = RT.var(String.format("%1$s_%2$s", beaker_clojure_ns, shellId),
String.format("%1$s_%2$s", loadFunctionPrefix, shellId));
clojure.lang.Compiler.load(new StringReader(clojureInitScript));
} catch (IOException e) {
logger.error(e.getMessage());
}
exit = false;
executor = new BeakerCellExecutor("clojure");
}
@Override
public void startWorker() {
myWorker = new workerThread();
myWorker.start();
}
public String getShellId() {
return shellId;
}
public void killAllThreads() {
executor.killAllThreads();
}
public void cancelExecution() {
executor.cancelExecution();
}
public void resetEnvironment() {
executor.killAllThreads();
loader = new DynamicClassLoaderSimple(ClassLoader.getSystemClassLoader());
loader.addJars(classPath);
loader.addDynamicDir(outDir);
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(loader);
for (String s : imports) {
if (s != null & !s.isEmpty())
try {
loader.loadClass(s);
clojureLoadString.invoke(String.format("(import '%s)", s));
} catch (Exception e) {
logger.error(e.getMessage());
}
}
for (String s : requirements) {
if (s != null && !s.isEmpty())
try {
clojureLoadString.invoke(String.format("(require '%s)", s));
} catch (Exception e) {
logger.error(e.getMessage());
}
}
Thread.currentThread().setContextClassLoader(oldLoader);
syncObject.release();
}
public void exit() {
exit = true;
cancelExecution();
syncObject.release();
}
public void evaluate(SimpleEvaluationObject seo, String code) {
// send job to thread
jobQueue.add(new jobDescriptor(code, seo));
syncObject.release();
}
private class workerThread extends Thread {
public workerThread() {
super("clojure worker");
}
/*
* This thread performs all the evaluation
*/
public void run() {
jobDescriptor j = null;
while (!exit) {
try {
// wait for work
syncObject.acquire();
// get next job descriptor
j = jobQueue.poll();
if (j == null)
continue;
j.outputObject.started();
if (!executor.executeTask(new MyRunnable(j.codeToBeExecuted, j.outputObject))) {
j.outputObject.error("... cancelled!");
}
} catch (Throwable e) {
logger.error(e.getMessage());
} finally {
if (j != null && j.outputObject != null) {
j.outputObject.executeCodeCallback();
}
}
}
}
private class MyRunnable implements Runnable {
private final String theCode;
private final SimpleEvaluationObject theOutput;
private MyRunnable(String code, SimpleEvaluationObject out) {
theCode = code;
theOutput = out;
}
@Override
public void run() {
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(loader);
theOutput.setOutputHandler();
Object result;
try {
InternalVariable.setValue(theOutput);
Object o = clojureLoadString.invoke(theCode);
try {
//workaround, checking of corrupted clojure objects
if (null != o) {
o.hashCode();
}
theOutput.finished(o);
} catch (Exception e) {
theOutput.error("Object: " + o.getClass() + ", value cannot be displayed due to following error: " + e.getMessage());
}
} catch (Throwable e) {
if (e instanceof InterruptedException || e instanceof InvocationTargetException || e instanceof ThreadDeath) {
theOutput.error("... cancelled!");
} else {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
if (null != e.getCause()) {
e.getCause().printStackTrace(pw);
} else {
e.printStackTrace(pw);
}
theOutput.error(sw.toString());
}
}
theOutput.setOutputHandler();
Thread.currentThread().setContextClassLoader(oldLoader);
}
}
}
@Override
public void setShellOptions(final KernelParameters kernelParameters) throws IOException {
Map<String, Object> params = kernelParameters.getParams();
Collection<String> listOfClassPath = (Collection<String>) params.get(CLASSPATH);
Collection<String> listOfImports = (Collection<String>) params.get(IMPORTS);
Map<String, String> env = System.getenv();
if (listOfClassPath == null || listOfClassPath.isEmpty()){
classPath = new ArrayList<>();
} else {
for (String line : listOfClassPath) {
if (!line.trim().isEmpty()) {
classPath.add(line);
}
}
}
if (listOfImports == null || listOfImports.isEmpty()){
imports = new ArrayList<>();
} else {
for (String line : listOfImports) {
if (!line.trim().isEmpty()) {
imports.add(line);
}
}
}
resetEnvironment();
}
public AutocompleteResult autocomplete(String code, int caretPosition) {
int i = caretPosition;
while (i > 0) {
char c = code.charAt(i - 1);
if (!Character.isUnicodeIdentifierStart(c) || "[]{}()/\\".indexOf(c) >= 0) {
break;
} else {
i--;
}
}
String _code = code.substring(i, caretPosition);
String apropos = "(repl_%1$s/apropos \"%2$s\")";
Object o = clojureLoadString.invoke(String.format(apropos, shellId, _code));
List<String> result = new ArrayList<String>();
for (Object s : ((Collection) o)) {
String whole = s.toString();
int d = whole.indexOf('/');
if (d > 0) {
String woNS = whole.substring(d + 1);
String ns = whole.substring(0, d);
result.add(woNS);
if (!currenClojureNS.equals(ns) && !"clojure.core".equals(ns)) result.add(whole);
} else {
result.add(whole);
}
}
return new AutocompleteResult(result, caretPosition);
}
}