/*
 * Copyright 2016 ThoughtWorks, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.thoughtworks.go.util;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.log4j.Logger;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;

/**
 * Loads the classes from the given jar and from the jars nested inside it.
 *
 * <p>Every {@code .jar} entry found in the outer jar is copied into a per-instance directory
 * under {@code data/njcl}; the outer jar and the extracted jars are then served through an
 * internal {@link URLClassLoader} whose parent is this loader. A class whose {@code .class}
 * resource is visible through that internal loader (and whose resource name does not contain
 * one of the {@code excludes} fragments) is loaded from the jar; every other class is loaded
 * through the supplied parent class loader.
 *
 * <p>The extraction directories are removed by JVM shutdown hooks.
 */
public class NestedJarClassLoader extends ClassLoader {
    private static final Logger LOGGER = Logger.getLogger(NestedJarClassLoader.class);
    private final ClassLoader jarClassLoader;
    private final ClassLoader parentClassLoader;
    private final String[] excludes;
    private final File jarDir;
    private static final File TEMP_DIR = new File("data/njcl");

    static {
        deleteOnShutdown(TEMP_DIR);
    }

    /**
     * Creates a loader for {@code jarURL} that falls back to the class loader which loaded
     * {@code NestedJarClassLoader} itself.
     *
     * @param jarURL   location of the outer jar
     * @param excludes fragments of resource names that must always be resolved by the parent
     */
    public NestedJarClassLoader(URL jarURL, String... excludes) {
        this(jarURL, NestedJarClassLoader.class.getClassLoader(), excludes);
    }

    /**
     * Creates a loader for {@code jarURL} with an explicit fallback class loader.
     *
     * @param jarURL            location of the outer jar
     * @param parentClassLoader loader used for everything that is not served from the jar
     * @param excludes          fragments of resource names that must always be resolved by the
     *                          parent; {@code null} is treated as "no excludes"
     */
    NestedJarClassLoader(URL jarURL, ClassLoader parentClassLoader, String... excludes) {
        // No JDK-level parent: delegation to parentClassLoader is done by hand in loadClass.
        super(null);
        this.jarDir = new File(TEMP_DIR, UUID.randomUUID().toString());
        this.parentClassLoader = parentClassLoader;
        // Assigned before the jar loader is built so that every field the overridden
        // lookup methods read is already initialised once `this` is handed out.
        this.excludes = excludes == null ? new String[0] : excludes;
        this.jarClassLoader = createLoaderForJar(jarURL);
        deleteOnShutdown(jarDir);
    }

    /**
     * Registers a JVM shutdown hook that quietly deletes {@code target}.
     *
     * <p>Kept static so the hook thread references only the file and not a class loader
     * instance, which would otherwise stay reachable until the JVM exits.
     */
    private static void deleteOnShutdown(final File target) {
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                FileUtils.deleteQuietly(target);
            }
        });
    }

    /**
     * Builds the internal {@link URLClassLoader} over the outer jar and its nested jars,
     * with this loader as its parent.
     */
    private ClassLoader createLoaderForJar(URL jarURL) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Creating Loader For jar: " + jarURL);
        }
        return new URLClassLoader(enumerateJar(jarURL), this);
    }

    /**
     * Returns the URL of the outer jar followed by the URLs of all nested jars extracted from it.
     *
     * <p>Enumeration is best effort: an {@link IOException} is logged and whatever was collected
     * up to that point is returned.
     */
    private URL[] enumerateJar(URL urlOfJar) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Enumerating jar: " + urlOfJar);
        }
        List<URL> urls = new ArrayList<>();
        urls.add(urlOfJar);
        // try-with-resources: the jar stream must be closed whether or not enumeration succeeds.
        try (JarInputStream jarStream = new JarInputStream(urlOfJar.openStream())) {
            JarEntry entry;
            while ((entry = jarStream.getNextJarEntry()) != null) {
                if (!entry.isDirectory() && entry.getName().endsWith(".jar")) {
                    urls.add(expandJarAndReturnURL(jarStream, entry));
                }
            }
        } catch (IOException e) {
            LOGGER.error("Failed to enumerate jar " + urlOfJar, e);
        }
        return urls.toArray(new URL[0]);
    }

    /**
     * Copies the current entry of {@code jarStream} into {@code jarDir} and returns the URL of
     * the written file.
     *
     * @throws IOException if the entry name resolves to a location outside {@code jarDir} or the
     *                     copy fails
     */
    private URL expandJarAndReturnURL(JarInputStream jarStream, JarEntry entry) throws IOException {
        File nestedJarFile = new File(jarDir, entry.getName());
        // Entry names come from the archive; refuse names such as "../x.jar" that would make
        // us write outside of our own extraction directory.
        String allowedPrefix = jarDir.getCanonicalPath() + File.separator;
        if (!nestedJarFile.getCanonicalPath().startsWith(allowedPrefix)) {
            throw new IOException(String.format("Jar entry %s resolves outside of %s", entry.getName(), jarDir));
        }
        nestedJarFile.getParentFile().mkdirs();
        try (FileOutputStream out = new FileOutputStream(nestedJarFile)) {
            IOUtils.copy(jarStream, out);
        }
        LOGGER.info(String.format("Exploded Entry %s from to %s", entry.getName(), nestedJarFile));
        return nestedJarFile.toURI().toURL();
    }

    /**
     * Loads {@code name} from the jar when it is present there and not excluded, otherwise from
     * the parent class loader.
     */
    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException {
        if (existsInTfsJar(name)) {
            return jarClassLoader.loadClass(name);
        }
        return parentClassLoader.loadClass(name);
    }

    /**
     * Refuses classes that are available from the jar and forwards everything else to the
     * parent class loader.
     *
     * <p>NOTE(review): this overload is presumably what the internal {@link URLClassLoader}
     * reaches when it delegates to its parent (this loader); throwing here makes it fall back
     * to its own jars - confirm against the JDK delegation model.
     *
     * @throws ClassNotFoundException if the class exists in the jar, or the parent cannot find it
     */
    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        if (existsInTfsJar(name)) {
            throw new ClassNotFoundException(name);
        }
        return invokeParentClassloader(name, resolve);
    }

    /**
     * Calls the non-public {@code loadClass(String, boolean)} of the parent class loader
     * reflectively.
     *
     * @throws ClassNotFoundException if that is what the parent threw
     * @throws RuntimeException       for any other failure of the reflective call
     */
    private Class<?> invokeParentClassloader(String name, boolean resolve) throws ClassNotFoundException {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(String.format("Invoking parent classloader for %s with resolve %s", name, resolve));
        }
        try {
            Method loadClass = findNonPublicMethod("loadClass", parentClassLoader.getClass(), String.class, boolean.class);
            return (Class<?>) loadClass.invoke(parentClassLoader, name, resolve);
        } catch (InvocationTargetException e) {
            // Surface a ClassNotFoundException unwrapped; anything else is reported below.
            handleClassNotFound(e);
            throw new RuntimeException("Failed to invoke parent classloader", e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException("Failed to invoke parent classloader", e);
        }
    }

    /**
     * Finds a method declared on {@code klass} or the nearest superclass declaring it and makes
     * it accessible.
     *
     * @throws IllegalStateException if no class in the hierarchy declares such a method
     */
    private Method findNonPublicMethod(String name, Class<?> klass, Class<?>... args) {
        for (Class<?> current = klass; current != null; current = current.getSuperclass()) {
            try {
                Method method = current.getDeclaredMethod(name, args);
                method.setAccessible(true);
                return method;
            } catch (NoSuchMethodException ignored) {
                // Not declared at this level; keep walking up the hierarchy.
            }
        }
        throw new IllegalStateException(String.format("Method %s not found in the hierarchy of %s", name, klass));
    }

    /** Rethrows the cause of {@code e} when it is a {@link ClassNotFoundException}. */
    private void handleClassNotFound(InvocationTargetException e) throws ClassNotFoundException {
        if (e.getCause() instanceof ClassNotFoundException) {
            throw (ClassNotFoundException) e.getCause();
        }
    }

    /**
     * Tells whether the class {@code name} should be served from the jar: its {@code .class}
     * resource must not be excluded and must be found by the internal jar loader.
     */
    private boolean existsInTfsJar(String name) {
        // The jar loader is still unset while the constructor is building it.
        if (jarClassLoader == null) {
            return false;
        }
        String classAsResourceName = name.replace('.', '/') + ".class";
        if (isExcluded(classAsResourceName)) {
            return false;
        }
        URL url = jarClassLoader.getResource(classAsResourceName);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(String.format("Loading %s from jar returned %s for url: %s ", name, url != null, url));
        }
        return url != null;
    }

    /**
     * Resolves only excluded resources, through the parent class loader; returns {@code null}
     * for every other name.
     */
    @Override
    public URL getResource(String name) {
        if (isExcluded(name)) {
            return parentClassLoader.getResource(name);
        }
        return null;
    }

    /** Returns {@code true} when {@code name} contains any of the configured exclude fragments. */
    private boolean isExcluded(String name) {
        for (String excluded : excludes) {
            if (name.contains(excluded)) {
                return true;
            }
        }
        return false;
    }
}