/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.zeppelin.spark.dep; import java.io.File; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URL; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import org.apache.commons.lang.StringUtils; import org.apache.spark.SparkContext; import org.apache.zeppelin.dep.AbstractDependencyResolver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.sonatype.aether.artifact.Artifact; import org.sonatype.aether.collection.CollectRequest; import org.sonatype.aether.graph.Dependency; import org.sonatype.aether.graph.DependencyFilter; import org.sonatype.aether.repository.RemoteRepository; import org.sonatype.aether.resolution.ArtifactResult; import org.sonatype.aether.resolution.DependencyRequest; import org.sonatype.aether.util.artifact.DefaultArtifact; import org.sonatype.aether.util.artifact.JavaScopes; import org.sonatype.aether.util.filter.DependencyFilterUtils; import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter; import scala.Some; import scala.collection.IndexedSeq; import scala.reflect.io.AbstractFile; import scala.tools.nsc.Global; import scala.tools.nsc.backend.JavaPlatform; import scala.tools.nsc.util.ClassPath; import scala.tools.nsc.util.MergedClassPath; /** * Deps resolver. * Add new dependencies from mvn repo (at runtime) to Spark interpreter group. */ public class SparkDependencyResolver extends AbstractDependencyResolver { Logger logger = LoggerFactory.getLogger(SparkDependencyResolver.class); private Global global; private ClassLoader runtimeClassLoader; private SparkContext sc; private final String[] exclusions = new String[] {"org.scala-lang:scala-library", "org.scala-lang:scala-compiler", "org.scala-lang:scala-reflect", "org.scala-lang:scalap", "org.apache.zeppelin:zeppelin-zengine", "org.apache.zeppelin:zeppelin-spark", "org.apache.zeppelin:zeppelin-server"}; public SparkDependencyResolver(Global global, ClassLoader runtimeClassLoader, SparkContext sc, String localRepoPath, String additionalRemoteRepository) { super(localRepoPath); this.global = global; this.runtimeClassLoader = runtimeClassLoader; this.sc = sc; addRepoFromProperty(additionalRemoteRepository); } private void addRepoFromProperty(String listOfRepo) { if (listOfRepo != null) { String[] repos = listOfRepo.split(";"); for (String repo : repos) { String[] parts = repo.split(","); if (parts.length == 3) { String id = parts[0].trim(); String url = parts[1].trim(); boolean isSnapshot = Boolean.parseBoolean(parts[2].trim()); if (id.length() > 1 && url.length() > 1) { addRepo(id, url, isSnapshot); } } } } } private void updateCompilerClassPath(URL[] urls) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { JavaPlatform platform = (JavaPlatform) global.platform(); MergedClassPath<AbstractFile> newClassPath = mergeUrlsIntoClassPath(platform, urls); Method[] methods = platform.getClass().getMethods(); for (Method m : methods) { if (m.getName().endsWith("currentClassPath_$eq")) { m.invoke(platform, new Some(newClassPath)); break; } } // NOTE: Must use reflection until this is exposed/fixed upstream in Scala List<String> classPaths = new LinkedList<>(); for (URL url : urls) { classPaths.add(url.getPath()); } // Reload all jars specified into our compiler global.invalidateClassPathEntries(scala.collection.JavaConversions.asScalaBuffer(classPaths) .toList()); } // Until spark 1.1.x // check https://github.com/apache/spark/commit/191d7cf2a655d032f160b9fa181730364681d0e7 private void updateRuntimeClassPath_1_x(URL[] urls) throws SecurityException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException { Method addURL; addURL = runtimeClassLoader.getClass().getDeclaredMethod("addURL", new Class[] {URL.class}); addURL.setAccessible(true); for (URL url : urls) { addURL.invoke(runtimeClassLoader, url); } } private void updateRuntimeClassPath_2_x(URL[] urls) throws SecurityException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException { Method addURL; addURL = runtimeClassLoader.getClass().getDeclaredMethod("addNewUrl", new Class[] {URL.class}); addURL.setAccessible(true); for (URL url : urls) { addURL.invoke(runtimeClassLoader, url); } } private MergedClassPath<AbstractFile> mergeUrlsIntoClassPath(JavaPlatform platform, URL[] urls) { IndexedSeq<ClassPath<AbstractFile>> entries = ((MergedClassPath<AbstractFile>) platform.classPath()).entries(); List<ClassPath<AbstractFile>> cp = new LinkedList<>(); for (int i = 0; i < entries.size(); i++) { cp.add(entries.apply(i)); } for (URL url : urls) { AbstractFile file; if ("file".equals(url.getProtocol())) { File f = new File(url.getPath()); if (f.isDirectory()) { file = AbstractFile.getDirectory(scala.reflect.io.File.jfile2path(f)); } else { file = AbstractFile.getFile(scala.reflect.io.File.jfile2path(f)); } } else { file = AbstractFile.getURL(url); } ClassPath<AbstractFile> newcp = platform.classPath().context().newClassPath(file); // distinct if (cp.contains(newcp) == false) { cp.add(newcp); } } return new MergedClassPath(scala.collection.JavaConversions.asScalaBuffer(cp).toIndexedSeq(), platform.classPath().context()); } public List<String> load(String artifact, boolean addSparkContext) throws Exception { return load(artifact, new LinkedList<String>(), addSparkContext); } public List<String> load(String artifact, Collection<String> excludes, boolean addSparkContext) throws Exception { if (StringUtils.isBlank(artifact)) { // Should throw here throw new RuntimeException("Invalid artifact to load"); } // <groupId>:<artifactId>[:<extension>[:<classifier>]]:<version> int numSplits = artifact.split(":").length; if (numSplits >= 3 && numSplits <= 6) { return loadFromMvn(artifact, excludes, addSparkContext); } else { loadFromFs(artifact, addSparkContext); LinkedList<String> libs = new LinkedList<>(); libs.add(artifact); return libs; } } private void loadFromFs(String artifact, boolean addSparkContext) throws Exception { File jarFile = new File(artifact); global.new Run(); if (sc.version().startsWith("1.1")) { updateRuntimeClassPath_1_x(new URL[] {jarFile.toURI().toURL()}); } else { updateRuntimeClassPath_2_x(new URL[] {jarFile.toURI().toURL()}); } if (addSparkContext) { sc.addJar(jarFile.getAbsolutePath()); } } private List<String> loadFromMvn(String artifact, Collection<String> excludes, boolean addSparkContext) throws Exception { List<String> loadedLibs = new LinkedList<>(); Collection<String> allExclusions = new LinkedList<>(); allExclusions.addAll(excludes); allExclusions.addAll(Arrays.asList(exclusions)); List<ArtifactResult> listOfArtifact; listOfArtifact = getArtifactsWithDep(artifact, allExclusions); Iterator<ArtifactResult> it = listOfArtifact.iterator(); while (it.hasNext()) { Artifact a = it.next().getArtifact(); String gav = a.getGroupId() + ":" + a.getArtifactId() + ":" + a.getVersion(); for (String exclude : allExclusions) { if (gav.startsWith(exclude)) { it.remove(); break; } } } List<URL> newClassPathList = new LinkedList<>(); List<File> files = new LinkedList<>(); for (ArtifactResult artifactResult : listOfArtifact) { logger.info("Load " + artifactResult.getArtifact().getGroupId() + ":" + artifactResult.getArtifact().getArtifactId() + ":" + artifactResult.getArtifact().getVersion()); newClassPathList.add(artifactResult.getArtifact().getFile().toURI().toURL()); files.add(artifactResult.getArtifact().getFile()); loadedLibs.add(artifactResult.getArtifact().getGroupId() + ":" + artifactResult.getArtifact().getArtifactId() + ":" + artifactResult.getArtifact().getVersion()); } global.new Run(); if (sc.version().startsWith("1.1")) { updateRuntimeClassPath_1_x(newClassPathList.toArray(new URL[0])); } else { updateRuntimeClassPath_2_x(newClassPathList.toArray(new URL[0])); } updateCompilerClassPath(newClassPathList.toArray(new URL[0])); if (addSparkContext) { for (File f : files) { sc.addJar(f.getAbsolutePath()); } } return loadedLibs; } /** * @param dependency * @param excludes list of pattern can either be of the form groupId:artifactId * @return * @throws Exception */ @Override public List<ArtifactResult> getArtifactsWithDep(String dependency, Collection<String> excludes) throws Exception { Artifact artifact = new DefaultArtifact(inferScalaVersion(dependency)); DependencyFilter classpathFilter = DependencyFilterUtils.classpathFilter(JavaScopes.COMPILE); PatternExclusionsDependencyFilter exclusionFilter = new PatternExclusionsDependencyFilter(inferScalaVersion(excludes)); CollectRequest collectRequest = new CollectRequest(); collectRequest.setRoot(new Dependency(artifact, JavaScopes.COMPILE)); synchronized (repos) { for (RemoteRepository repo : repos) { collectRequest.addRepository(repo); } } DependencyRequest dependencyRequest = new DependencyRequest(collectRequest, DependencyFilterUtils.andFilter(exclusionFilter, classpathFilter)); return system.resolveDependencies(session, dependencyRequest).getArtifactResults(); } public static Collection<String> inferScalaVersion(Collection<String> artifact) { List<String> list = new LinkedList<>(); for (String a : artifact) { list.add(inferScalaVersion(a)); } return list; } public static String inferScalaVersion(String artifact) { int pos = artifact.indexOf(":"); if (pos < 0 || pos + 2 >= artifact.length()) { // failed to infer return artifact; } if (':' == artifact.charAt(pos + 1)) { String restOfthem = ""; String versionSep = ":"; String groupId = artifact.substring(0, pos); int nextPos = artifact.indexOf(":", pos + 2); if (nextPos < 0) { if (artifact.charAt(artifact.length() - 1) == '*') { nextPos = artifact.length() - 1; versionSep = ""; restOfthem = "*"; } else { versionSep = ""; nextPos = artifact.length(); } } String artifactId = artifact.substring(pos + 2, nextPos); if (nextPos < artifact.length()) { if (!restOfthem.equals("*")) { restOfthem = artifact.substring(nextPos + 1); } } String [] version = scala.util.Properties.versionNumberString().split("[.]"); String scalaVersion = version[0] + "." + version[1]; return groupId + ":" + artifactId + "_" + scalaVersion + versionSep + restOfthem; } else { return artifact; } } }