/* * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.boot.junit.runner.classpath; import java.io.File; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.net.URL; import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.jar.Attributes; import java.util.jar.JarFile; import org.apache.maven.repository.internal.MavenRepositorySystemUtils; import org.eclipse.aether.DefaultRepositorySystemSession; import org.eclipse.aether.RepositorySystem; import org.eclipse.aether.artifact.DefaultArtifact; import org.eclipse.aether.collection.CollectRequest; import org.eclipse.aether.connector.basic.BasicRepositoryConnectorFactory; import org.eclipse.aether.graph.Dependency; import org.eclipse.aether.impl.DefaultServiceLocator; import org.eclipse.aether.repository.LocalRepository; import org.eclipse.aether.repository.RemoteRepository; import org.eclipse.aether.resolution.ArtifactResult; import org.eclipse.aether.resolution.DependencyRequest; import org.eclipse.aether.resolution.DependencyResult; import org.eclipse.aether.spi.connector.RepositoryConnectorFactory; import org.eclipse.aether.spi.connector.transport.TransporterFactory; import org.eclipse.aether.transport.http.HttpTransporterFactory; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.TestClass; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.util.AntPathMatcher; import org.springframework.util.StringUtils; /** * A custom {@link BlockJUnit4ClassRunner} that runs tests using a modified class path. * Entries are excluded from the class path using {@link ClassPathExclusions} and * overridden using {@link ClassPathOverrides} on the test class. A class loader is * created with the customized class path and is used both to load the test class and as * the thread context class loader while the test is being run. * * @author Andy Wilkinson */ public class ModifiedClassPathRunner extends BlockJUnit4ClassRunner { public ModifiedClassPathRunner(Class<?> testClass) throws InitializationError { super(testClass); } @Override protected TestClass createTestClass(Class<?> testClass) { try { ClassLoader classLoader = createTestClassLoader(testClass); return new ModifiedClassPathTestClass(classLoader, testClass.getName()); } catch (Exception ex) { throw new IllegalStateException(ex); } } @Override protected Object createTest() throws Exception { ModifiedClassPathTestClass testClass = (ModifiedClassPathTestClass) getTestClass(); return testClass.doWithModifiedClassPathThreadContextClassLoader( new ModifiedClassPathTestClass.ModifiedClassPathTcclAction<Object, Exception>() { @Override public Object perform() throws Exception { return ModifiedClassPathRunner.super.createTest(); } }); } private URLClassLoader createTestClassLoader(Class<?> testClass) throws Exception { URLClassLoader classLoader = (URLClassLoader) this.getClass().getClassLoader(); return new ModifiedClassPathClassLoader( processUrls(extractUrls(classLoader), testClass), classLoader.getParent(), classLoader); } private URL[] extractUrls(URLClassLoader classLoader) throws Exception { List<URL> extractedUrls = new ArrayList<URL>(); for (URL url : classLoader.getURLs()) { if (isSurefireBooterJar(url)) { extractedUrls.addAll(extractUrlsFromManifestClassPath(url)); } else { extractedUrls.add(url); } } return extractedUrls.toArray(new URL[extractedUrls.size()]); } private boolean isSurefireBooterJar(URL url) { return url.getPath().contains("surefirebooter"); } private List<URL> extractUrlsFromManifestClassPath(URL booterJar) throws Exception { List<URL> urls = new ArrayList<URL>(); for (String entry : getClassPath(booterJar)) { urls.add(new URL(entry)); } return urls; } private String[] getClassPath(URL booterJar) throws Exception { JarFile jarFile = new JarFile(new File(booterJar.toURI())); try { return StringUtils.delimitedListToStringArray(jarFile.getManifest() .getMainAttributes().getValue(Attributes.Name.CLASS_PATH), " "); } finally { jarFile.close(); } } private URL[] processUrls(URL[] urls, Class<?> testClass) throws Exception { ClassPathEntryFilter filter = new ClassPathEntryFilter(testClass); List<URL> processedUrls = new ArrayList<URL>(); processedUrls.addAll(getAdditionalUrls(testClass)); for (URL url : urls) { if (!filter.isExcluded(url)) { processedUrls.add(url); } } return processedUrls.toArray(new URL[processedUrls.size()]); } private List<URL> getAdditionalUrls(Class<?> testClass) throws Exception { ClassPathOverrides overrides = AnnotationUtils.findAnnotation(testClass, ClassPathOverrides.class); if (overrides == null) { return Collections.emptyList(); } return resolveCoordinates(overrides.value()); } private List<URL> resolveCoordinates(String[] coordinates) throws Exception { DefaultServiceLocator serviceLocator = MavenRepositorySystemUtils .newServiceLocator(); serviceLocator.addService(RepositoryConnectorFactory.class, BasicRepositoryConnectorFactory.class); serviceLocator.addService(TransporterFactory.class, HttpTransporterFactory.class); RepositorySystem repositorySystem = serviceLocator .getService(RepositorySystem.class); DefaultRepositorySystemSession session = MavenRepositorySystemUtils.newSession(); LocalRepository localRepository = new LocalRepository( System.getProperty("user.home") + "/.m2/repository"); session.setLocalRepositoryManager( repositorySystem.newLocalRepositoryManager(session, localRepository)); CollectRequest collectRequest = new CollectRequest(null, Arrays.asList(new RemoteRepository.Builder("central", "default", "http://central.maven.org/maven2").build())); collectRequest.setDependencies(createDependencies(coordinates)); DependencyRequest dependencyRequest = new DependencyRequest(collectRequest, null); DependencyResult result = repositorySystem.resolveDependencies(session, dependencyRequest); List<URL> resolvedArtifacts = new ArrayList<URL>(); for (ArtifactResult artifact : result.getArtifactResults()) { resolvedArtifacts.add(artifact.getArtifact().getFile().toURI().toURL()); } return resolvedArtifacts; } private List<Dependency> createDependencies(String[] allCoordinates) { List<Dependency> dependencies = new ArrayList<Dependency>(); for (String coordinate : allCoordinates) { dependencies.add(new Dependency(new DefaultArtifact(coordinate), null)); } return dependencies; } /** * Filter for class path entries. */ private static final class ClassPathEntryFilter { private final List<String> exclusions; private final AntPathMatcher matcher = new AntPathMatcher(); private ClassPathEntryFilter(Class<?> testClass) throws Exception { ClassPathExclusions exclusions = AnnotationUtils.findAnnotation(testClass, ClassPathExclusions.class); this.exclusions = exclusions == null ? Collections.<String>emptyList() : Arrays.asList(exclusions.value()); } private boolean isExcluded(URL url) throws Exception { if (!"file".equals(url.getProtocol())) { return false; } String name = new File(url.toURI()).getName(); for (String exclusion : this.exclusions) { if (this.matcher.match(exclusion, name)) { return true; } } return false; } } /** * Custom {@link TestClass} that uses a modified class path. */ private static final class ModifiedClassPathTestClass extends TestClass { private final ClassLoader classLoader; ModifiedClassPathTestClass(ClassLoader classLoader, String testClassName) throws ClassNotFoundException { super(classLoader.loadClass(testClassName)); this.classLoader = classLoader; } @Override public List<FrameworkMethod> getAnnotatedMethods( Class<? extends Annotation> annotationClass) { try { return getAnnotatedMethods(annotationClass.getName()); } catch (ClassNotFoundException ex) { throw new RuntimeException(ex); } } @SuppressWarnings("unchecked") private List<FrameworkMethod> getAnnotatedMethods(String annotationClassName) throws ClassNotFoundException { Class<? extends Annotation> annotationClass = (Class<? extends Annotation>) this.classLoader .loadClass(annotationClassName); List<FrameworkMethod> methods = super.getAnnotatedMethods(annotationClass); return wrapFrameworkMethods(methods); } private List<FrameworkMethod> wrapFrameworkMethods( List<FrameworkMethod> methods) { List<FrameworkMethod> wrapped = new ArrayList<FrameworkMethod>( methods.size()); for (FrameworkMethod frameworkMethod : methods) { wrapped.add(new ModifiedClassPathFrameworkMethod( frameworkMethod.getMethod())); } return wrapped; } private <T, E extends Throwable> T doWithModifiedClassPathThreadContextClassLoader( ModifiedClassPathTcclAction<T, E> action) throws E { ClassLoader originalClassLoader = Thread.currentThread() .getContextClassLoader(); Thread.currentThread().setContextClassLoader(this.classLoader); try { return action.perform(); } finally { Thread.currentThread().setContextClassLoader(originalClassLoader); } } /** * An action to be performed with the {@link ModifiedClassPathClassLoader} set as * the thread context class loader. */ private interface ModifiedClassPathTcclAction<T, E extends Throwable> { T perform() throws E; } /** * Custom {@link FrameworkMethod} that runs methods with * {@link ModifiedClassPathClassLoader} as the thread context class loader. */ private final class ModifiedClassPathFrameworkMethod extends FrameworkMethod { private ModifiedClassPathFrameworkMethod(Method method) { super(method); } @Override public Object invokeExplosively(final Object target, final Object... params) throws Throwable { return doWithModifiedClassPathThreadContextClassLoader( new ModifiedClassPathTcclAction<Object, Throwable>() { @Override public Object perform() throws Throwable { return ModifiedClassPathFrameworkMethod.super.invokeExplosively( target, params); } }); } } } /** * Custom {@link URLClassLoader} that modifies the class path. */ private static final class ModifiedClassPathClassLoader extends URLClassLoader { private final ClassLoader junitLoader; ModifiedClassPathClassLoader(URL[] urls, ClassLoader parent, ClassLoader junitLoader) { super(urls, parent); this.junitLoader = junitLoader; } @Override public Class<?> loadClass(String name) throws ClassNotFoundException { if (name.startsWith("org.junit") || name.startsWith("org.hamcrest")) { return this.junitLoader.loadClass(name); } return super.loadClass(name); } } }