package com.thinkbiganalytics.spark.repl;

/*-
 * #%L
 * thinkbig-commons-spark-repl
 * %%
 * Copyright (C) 2017 ThinkBig Analytics
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.base.Joiner;
import com.google.common.base.Throwables;
import com.thinkbiganalytics.spark.SparkInterpreterBuilder;

import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.io.InputStream;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.regex.Pattern;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.ScriptException;

import scala.collection.JavaConversions;
import scala.tools.nsc.Settings;
import scala.tools.nsc.interpreter.IMain;
import scala.tools.nsc.interpreter.Results;

/**
 * Evaluates Scala scripts using the Spark REPL interface.
 */
@Component
@ComponentScan("com.thinkbiganalytics.spark")
public class SparkScriptEngine extends ScriptEngine {

    private static final Logger log = LoggerFactory.getLogger(SparkScriptEngine.class);

    /**
     * Matches a multi-line comment in Scala
     */
    private static final Pattern COMMENT = Pattern.compile("/\\*.*\\*/");

    /**
     * Matches continuation lines in Scala (lines that begin with a '.')
     */
    private static final Pattern LINE_CONTINUATION = Pattern.compile("^\\s*\\.");

    /**
     * Spark configuration
     */
    @Autowired
    private SparkConf conf;

    /**
     * List of patterns to deny in scripts
     */
    @Nullable
    private List<Pattern> denyPatterns;

    /**
     * Spark REPL interface
     */
    @Nullable
    private IMain interpreter;

    /**
     * Builds Spark REPL interpreter instances
     */
    @Autowired
    private SparkInterpreterBuilder builder;

    @Nonnull
    @Override
    public ClassLoader getClassLoader() {
        // Get current context class loader
        final Thread currentThread = Thread.currentThread();
        final ClassLoader contextClassLoader = currentThread.getContextClassLoader();

        // Get interpreter class loader from context
        getInterpreter().setContextClassLoader();
        final ClassLoader interpreterClassLoader = currentThread.getContextClassLoader();

        // Reset context
        currentThread.setContextClassLoader(contextClassLoader);
        return interpreterClassLoader;
    }

    @Nonnull
    @Override
    protected SparkContext createSparkContext() {
        // Allow interpreter to modify Thread context for Spark
        getInterpreter().setContextClassLoader();

        // The SparkContext ClassLoader is needed during initialization (only for YARN master)
        return executeWithSparkClassLoader(new Callable<SparkContext>() {
            @Override
            public SparkContext call() throws Exception {
                log.info("Creating spark context with spark conf {}", conf);
                return new SparkContext(conf);
            }
        });
    }
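    // Illustration of the pre-processing done by execute() below (hypothetical
    // input, not part of the original source). A script such as
    //
    //     val df = sqlContext.sql("SELECT 1")
    //       .cache()
    //
    // is flattened to
    //
    //     ;val df = sqlContext.sql("SELECT 1")  .cache()
    //
    // so that a deny pattern can match a statement even when it is split across
    // lines; continuation lines (leading '.') are joined without a ';' separator.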
    @Override
    protected void execute(@Nonnull final String script) throws ScriptException {
        log.debug("Executing script:\n{}", script);

        // Convert script to single line (for checking security violations)
        final StringBuilder safeScriptBuilder = new StringBuilder(script.length());

        for (final String line : script.split("\n")) {
            if (!LINE_CONTINUATION.matcher(line).find()) {
                safeScriptBuilder.append(';');
            }
            safeScriptBuilder.append(line);
        }

        final String safeScript = COMMENT.matcher(safeScriptBuilder.toString()).replaceAll("");

        // Check for security violations
        for (final Pattern pattern : getDenyPatterns()) {
            if (pattern.matcher(safeScript).find()) {
                log.error("Not executing script that matches deny pattern: {}", pattern);
                throw new ScriptException("Script not executed due to security policy.");
            }
        }

        // Execute script
        try {
            getInterpreter().interpret(safeScript);
        } catch (final AssertionError e) {
            log.warn("Caught assertion error when executing script. Retrying...", e);
            reset();
            getInterpreter().interpret(safeScript);
        }
    }

    @Override
    protected void reset() {
        super.reset();

        // Clear the interpreter
        if (interpreter != null) {
            interpreter.close();
            interpreter = null;
        }
    }

    /**
     * Executes the specified callable after replacing the current context class loader.
     *
     * <p>This is a work-around to avoid {@link ClassCastException} issues caused by conflicts between Hadoop and Kylo Spark Shell. Spark uses the context class loader when loading Hadoop
     * components for running Spark on YARN. When both Hadoop and Kylo Spark Shell provide the same class then both classes are loaded when creating a {@link SparkContext}. The fix is to set
     * the context class loader to the same class loader that was used to load the {@link SparkContext} class.</p>
     *
     * @param callable the function to be executed
     * @param <T>      the return type
     * @return the return value
     */
    private <T> T executeWithSparkClassLoader(@Nonnull final Callable<T> callable) {
        // Set context class loader
        final Thread currentThread = Thread.currentThread();
        final ClassLoader contextClassLoader = currentThread.getContextClassLoader();

        final ClassLoader sparkClassLoader = new ForwardingClassLoader(SparkContext.class.getClassLoader(), contextClassLoader);
        currentThread.setContextClassLoader(sparkClassLoader);

        // Execute callable
        try {
            return callable.call();
        } catch (final Exception e) {
            throw Throwables.propagate(e);
        } finally {
            // Reset context class loader
            currentThread.setContextClassLoader(contextClassLoader);
        }
    }
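    // Assumed format of spark-deny-patterns.conf, inferred from the parsing in
    // getDenyPatterns() below: one Java regular expression per line, lines
    // starting with '#' are comments, and blank lines are ignored. A
    // hypothetical example:
    //
    //     # Block direct access to the runtime
    //     java\.lang\.Runtime
    //     java\.lang\.ProcessBuilder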
    /**
     * Gets the list of patterns that should prevent a script from executing.
     *
     * @return the deny patterns list
     * @throws IllegalStateException if the spark-deny-patterns.conf file cannot be loaded
     */
    @Nonnull
    private List<Pattern> getDenyPatterns() {
        if (denyPatterns == null) {
            // Load custom deny patterns from the classpath root, falling back to the default patterns in this package
            String resourceName = "spark-deny-patterns.conf";
            InputStream resourceStream = getClass().getResourceAsStream("/" + resourceName);

            if (resourceStream == null) {
                resourceName = "spark-deny-patterns.default.conf";
                resourceStream = getClass().getResourceAsStream(resourceName);
            }

            // Parse lines
            final List<String> denyPatternLines;

            if (resourceStream != null) {
                try {
                    denyPatternLines = IOUtils.readLines(resourceStream, "UTF-8");
                    log.info("Loaded Spark deny patterns from {}.", resourceName);
                } catch (final IOException e) {
                    throw new IllegalStateException("Unable to load " + resourceName, e);
                }
            } else {
                log.info("No Spark deny patterns found. All scripts will be permitted.");
                denyPatternLines = Collections.emptyList();
            }

            // Compile patterns, skipping comment and blank lines
            denyPatterns = new ArrayList<>();

            for (final String line : denyPatternLines) {
                final String trimLine = line.trim();
                if (!trimLine.startsWith("#") && !trimLine.isEmpty()) {
                    denyPatterns.add(Pattern.compile(trimLine));
                }
            }
        }
        return denyPatterns;
    }

    /**
     * Gets the Spark REPL interface to be used.
     *
     * @return the interpreter
     */
    @Nonnull
    private IMain getInterpreter() {
        if (this.interpreter == null) {
            // Determine engine settings
            final Settings settings = getSettings();

            // Initialize engine
            final ClassLoader parentClassLoader = getClass().getClassLoader();
            final SparkInterpreterBuilder b = this.builder.withSettings(settings)
                .withPrintWriter(getPrintWriter())
                .withClassLoader(parentClassLoader);
            final IMain interpreter = b.newInstance();
            interpreter.setContextClassLoader();
            interpreter.initializeSynchronous();

            // Setup environment by binding this engine as the "engine" variable in scripts
            final scala.collection.immutable.List<String> empty = JavaConversions.asScalaBuffer(new ArrayList<String>()).toList();
            final Results.Result result = interpreter.bind("engine", SparkScriptEngine.class.getName(), this, empty);

            if (result instanceof Results.Error$) {
                throw new IllegalStateException("Failed to initialize interpreter");
            }

            this.interpreter = interpreter;
        }
        return this.interpreter;
    }

    /**
     * Gets the settings for the interpreter.
     *
     * @return the interpreter settings
     */
    @Nonnull
    private Settings getSettings() {
        final Settings settings = new Settings();

        if (settings.classpath().isDefault()) {
            // Combine the URLs of this class loader with the JVM class path so the REPL compiler can resolve application classes
            final String classPath = Joiner.on(':').join(((URLClassLoader) getClass().getClassLoader()).getURLs()) + ":" + System.getProperty("java.class.path");
            settings.classpath().value_$eq(classPath);
        }
        return settings;
    }
}
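/*
 * Usage sketch (hypothetical; assumes a Spring application context that supplies
 * the SparkConf and SparkInterpreterBuilder beans required above):
 *
 *     ApplicationContext ctx =
 *         new AnnotationConfigApplicationContext("com.thinkbiganalytics.spark");
 *     SparkScriptEngine engine = ctx.getBean(SparkScriptEngine.class);
 *
 * Scripts passed to execute(...) run in the Spark REPL and may reference this
 * engine through the bound "engine" variable set up in getInterpreter().
 */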