/* * Copyright © 2016 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.app.runtime.spark; import co.cask.cdap.common.conf.CConfiguration; import co.cask.cdap.common.internal.guava.ClassPath; import co.cask.cdap.common.lang.ClassLoaders; import co.cask.cdap.common.lang.ClassPathResources; import co.cask.cdap.common.lang.FilterClassLoader; import co.cask.cdap.common.lang.ProgramClassLoader; import co.cask.cdap.common.lang.WeakReferenceDelegatorClassLoader; import co.cask.cdap.common.lang.jar.BundleJarUtil; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Predicate; import com.google.common.collect.Iterables; import com.google.common.io.OutputSupplier; import com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.spark.SparkConf; import org.apache.spark.streaming.DStreamGraph; import org.apache.spark.streaming.StreamingContext; import org.apache.twill.common.Cancellable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; import scala.collection.parallel.TaskSupport; import scala.collection.parallel.ThreadPoolTaskSupport; import scala.collection.parallel.mutable.ParArray; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.util.Collections; import java.util.Properties; import java.util.Set; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; /** * Util class for common functions needed for Spark implementation. */ public final class SparkRuntimeUtils { private static final Logger LOG = LoggerFactory.getLogger(SparkRuntimeUtils.class); // ClassLoader filter @VisibleForTesting public static final FilterClassLoader.Filter SPARK_PROGRAM_CLASS_LOADER_FILTER = new FilterClassLoader.Filter() { final FilterClassLoader.Filter defaultFilter = FilterClassLoader.defaultFilter(); volatile Set<ClassPath.ResourceInfo> sparkStreamingResources; @Override public boolean acceptResource(final String resource) { // All Spark API, Spark, Scala and Akka classes should come from parent. if (resource.startsWith("co/cask/cdap/api/spark/")) { return true; } if (resource.startsWith("scala/")) { return true; } if (resource.startsWith("akka/")) { return true; } if (resource.startsWith("org/apache/spark/")) { // Only allows the core Spark Streaming classes, but not any streaming extensions (like Kafka). if (resource.startsWith("org/apache/spark/streaming")) { return Iterables.any(getSparkStreamingResources(), new Predicate<ClassPath.ResourceInfo>() { @Override public boolean apply(ClassPath.ResourceInfo input) { return input.getResourceName().equals(resource); } }); } return true; } return defaultFilter.acceptResource(resource); } @Override public boolean acceptPackage(final String packageName) { if (packageName.equals("co.cask.cdap.api.spark") || packageName.startsWith("co.cask.cdap.api.spark.")) { return true; } if (packageName.equals("scala") || packageName.startsWith("scala.")) { return true; } if (packageName.equals("akka") || packageName.startsWith("akka.")) { return true; } if (packageName.equals("org.apache.spark") || packageName.startsWith("org.apache.spark.")) { // Only allows the core Spark Streaming classes, but not any streaming extensions (like Kafka). if (packageName.equals("org.apache.spark.streaming") || packageName.startsWith("org.apache.spark.streaming.")) { return Iterables.any( Iterables.filter(getSparkStreamingResources(), ClassPath.ClassInfo.class), new Predicate<ClassPath.ClassInfo>() { @Override public boolean apply(ClassPath.ClassInfo input) { return input.getPackageName().equals(packageName); } }); } return true; } return defaultFilter.acceptResource(packageName); } /** * Gets the set of resources information that are from the Spark Streaming Core. It excludes any * Spark streaming extensions, such as Kafka or Flume. They need to be excluded since they are not * part of Spark distribution and it should be loaded from the user program ClassLoader. This filtering * is needed for unit-testing because in unit-test, those extension classes are loadable from the system * classloader, causing same classes being loaded through different classloader. */ private Set<ClassPath.ResourceInfo> getSparkStreamingResources() { if (sparkStreamingResources != null) { return sparkStreamingResources; } synchronized (this) { if (sparkStreamingResources != null) { return sparkStreamingResources; } try { sparkStreamingResources = ClassPathResources.getClassPathResources(getClass().getClassLoader(), StreamingContext.class); } catch (IOException e) { LOG.warn("Failed to find resources for Spark StreamingContext.", e); sparkStreamingResources = Collections.emptySet(); } return sparkStreamingResources; } } }; /** * Creates a {@link ProgramClassLoader} that have Spark classes visible. */ public static ProgramClassLoader createProgramClassLoader(CConfiguration cConf, File dir, ClassLoader unfilteredClassLoader) { ClassLoader parent = new FilterClassLoader(unfilteredClassLoader, SPARK_PROGRAM_CLASS_LOADER_FILTER); return new ProgramClassLoader(cConf, dir, parent); } /** * Creates a zip file which contains a serialized {@link Properties} with a given zip entry name, together with * all files under the given directory. This is called from Client.createConfArchive() as a workaround for the * SPARK-13441 bug. * * @param sparkConf the {@link SparkConf} to save * @param propertiesEntryName name of the zip entry for the properties * @param confDirPath directory to scan for files to include in the zip file * @param outputZipPath output file * @return the zip file */ public static File createConfArchive(SparkConf sparkConf, final String propertiesEntryName, String confDirPath, String outputZipPath) { final Properties properties = new Properties(); for (Tuple2<String, String> tuple : sparkConf.getAll()) { properties.put(tuple._1(), tuple._2()); } try { File confDir = new File(confDirPath); final File zipFile = new File(outputZipPath); BundleJarUtil.createArchive(confDir, new OutputSupplier<ZipOutputStream>() { @Override public ZipOutputStream getOutput() throws IOException { ZipOutputStream zipOutput = new ZipOutputStream(new FileOutputStream(zipFile)); zipOutput.putNextEntry(new ZipEntry(propertiesEntryName)); properties.store(zipOutput, "Spark configuration."); zipOutput.closeEntry(); return zipOutput; } }); LOG.debug("Spark config archive created at {} from {}", zipFile, confDir); return zipFile; } catch (IOException e) { throw new RuntimeException(e); } } /** * Sets the context ClassLoader to the given {@link SparkClassLoader}. It will also set the * ClassLoader for the {@link Configuration} contained inside the {@link SparkClassLoader}. * * @return a {@link Cancellable} to reset the classloader to the one prior to the call */ public static Cancellable setContextClassLoader(final SparkClassLoader sparkClassLoader) { final Configuration hConf = sparkClassLoader.getRuntimeContext().getConfiguration(); final ClassLoader oldConfClassLoader = hConf.getClassLoader(); // Always wrap it with WeakReference to avoid ClassLoader leakage from Spark. ClassLoader classLoader = new WeakReferenceDelegatorClassLoader(sparkClassLoader); hConf.setClassLoader(classLoader); final ClassLoader oldClassLoader = ClassLoaders.setContextClassLoader(classLoader); return new Cancellable() { @Override public void cancel() { hConf.setClassLoader(oldConfClassLoader); ClassLoaders.setContextClassLoader(oldClassLoader); // Do not remove the next line. // This is necessary to keep a strong reference to the SparkClassLoader so that it won't get GC until this // cancel() is called LOG.trace("Reset context ClassLoader. The SparkClassLoader is: {}", sparkClassLoader); } }; } /** * Sets the {@link TaskSupport} for the given Scala {@link ParArray} to {@link ThreadPoolTaskSupport}. * This method is mainly used by {@link SparkRunnerClassLoader} to set the {@link TaskSupport} for the * parallel array used inside the {@link DStreamGraph} class in spark to avoid thread leakage after the * Spark program execution finished. */ @SuppressWarnings("unused") public static <T> ParArray<T> setTaskSupport(ParArray<T> parArray) { ThreadPoolExecutor executor = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1, TimeUnit.SECONDS, new SynchronousQueue<Runnable>(), new ThreadFactoryBuilder() .setNameFormat("task-support-%d").build()); executor.allowCoreThreadTimeOut(true); parArray.tasksupport_$eq(new ThreadPoolTaskSupport(executor)); return parArray; } private SparkRuntimeUtils() { // private } }