package org.mockitoutil;

import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

/**
 * Custom classloader to load classes in a hierarchic realm.
 *
 * <p>Class names accepted by the {@link ReloadClassPredicate} are defined by this loader itself
 * (once per loader instance, then served from an internal cache). Every other name is resolved
 * through the inherited {@link ClassLoader#loadClass(String)}.
 *
 * <p>The loader's search path is made of the classpath roots that contain this class,
 * {@code org.mockito.Mockito} and {@code net.bytebuddy.ByteBuddy}.
 */
public class SimplePerRealmReloadingClassLoader extends URLClassLoader {

    /** Classes already defined by this loader, keyed by fully qualified class name. */
    private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>();

    /** Decides which class names are defined by this loader instead of being delegated. */
    private final ReloadClassPredicate reloadClassPredicate;

    /**
     * Creates a realm that uses the default parent of {@link URLClassLoader#URLClassLoader(URL[])}.
     *
     * @param reloadClassPredicate selects the class names this loader defines itself
     * @throws IllegalStateException if one of the probed classes cannot be located as a resource
     */
    public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls());
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /**
     * Creates a realm with an explicit parent class loader.
     *
     * @param parentClassLoader    parent handed to {@link URLClassLoader}
     * @param reloadClassPredicate selects the class names this loader defines itself
     * @throws IllegalStateException if one of the probed classes cannot be located as a resource
     */
    public SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls(), parentClassLoader);
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /** Returns the classpath roots of this class, of Mockito and of ByteBuddy, in that order. */
    private static URL[] getPossibleClassPathsUrls() {
        return new URL[]{
                obtainClassPath(),
                obtainClassPath("org.mockito.Mockito"),
                obtainClassPath("net.bytebuddy.ByteBuddy")
        };
    }

    /** Returns the classpath root that contains this very class. */
    private static URL obtainClassPath() {
        return obtainClassPath(SimplePerRealmReloadingClassLoader.class.getName());
    }

    /**
     * Returns the classpath root (directory or archive URL prefix) from which {@code className}
     * is visible to the class loader of this class.
     *
     * @param className fully qualified name of the class to locate
     * @throws IllegalStateException if the class file cannot be found as a resource
     * @throws RuntimeException      if the derived root is not a well-formed URL
     */
    private static URL obtainClassPath(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path);
        if (resource == null) {
            // getResource signals "not found" with null; fail with the offending name
            // instead of an anonymous NullPointerException.
            throw new IllegalStateException(
                    "Classloader couldn't find class '" + className + "' (resource '" + path + "') on the classpath");
        }

        String url = resource.toExternalForm();
        try {
            // Strip the trailing "<package path>/<Class>.class" to keep only the root.
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /**
     * Loads a class, defining it in this realm when the predicate accepts its name.
     *
     * @param qualifiedClassName fully qualified name of the class to load
     * @return the class defined by this loader (cached after the first call) when the predicate
     *         accepts the name, otherwise whatever the inherited delegation returns
     * @throws ClassNotFoundException if the class cannot be found
     */
    @Override
    public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException {
        if (!reloadClassPredicate.acceptReloadOf(qualifiedClassName)) {
            return useParentClassLoaderFor(qualifiedClassName);
        }

        Class<?> alreadyDefined = classHashMap.get(qualifiedClassName);
        if (alreadyDefined != null) {
            return alreadyDefined;
        }

        // Bypass the usual delegation: findClass defines the class from this loader's own URLs.
        Class<?> foundClass = findClass(qualifiedClassName);
        saveFoundClass(qualifiedClassName, foundClass);
        return foundClass;
    }

    /** Remembers a class defined by this loader so later lookups return the same instance. */
    private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) {
        classHashMap.put(qualifiedClassName, foundClass);
    }

    /** Resolves the name through the inherited {@link ClassLoader#loadClass(String)}. */
    private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException {
        return super.loadClass(qualifiedName);
    }

    /**
     * Instantiates the named class through its public no-arg constructor inside this realm and
     * invokes it as a {@link Callable}.
     *
     * <p>This loader is installed as the thread context class loader for the duration of the
     * call and the previous one is restored afterwards.
     *
     * @param callableCalledInClassLoaderRealm fully qualified name of a class implementing {@link Callable}
     * @return the value returned by {@link Callable#call()}
     * @throws IllegalArgumentException if the instantiated object is not a {@link Callable}
     * @throws Exception                anything raised while loading, instantiating or calling
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception {
        return instantiateAndCall(callableCalledInClassLoaderRealm, new Class<?>[0], new Object[0]);
    }

    /**
     * Instantiates the named class through the public constructor matching {@code argTypes}
     * inside this realm and invokes it as a {@link Callable}.
     *
     * <p>This loader is installed as the thread context class loader for the duration of the
     * call and the previous one is restored afterwards.
     *
     * @param callableCalledInClassLoaderRealm fully qualified name of a class implementing {@link Callable}
     * @param argTypes                         parameter types of the constructor to use
     * @param args                             arguments passed to that constructor
     * @return the value returned by {@link Callable#call()}
     * @throws IllegalArgumentException if the instantiated object is not a {@link Callable}
     * @throws Exception                anything raised while loading, instantiating or calling
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception {
        return instantiateAndCall(callableCalledInClassLoaderRealm, argTypes, args);
    }

    /** Shared implementation of both {@code doInRealm} overloads. */
    private Object instantiateAndCall(String callableClassName, Class<?>[] argTypes, Object[] args) throws Exception {
        ClassLoader current = Thread.currentThread().getContextClassLoader();
        try {
            Thread.currentThread().setContextClassLoader(this);
            Object instance = this.loadClass(callableClassName).getConstructor(argTypes).newInstance(args);
            if (instance instanceof Callable) {
                return ((Callable<?>) instance).call();
            }
        } finally {
            // Always hand the thread back with its original context class loader.
            Thread.currentThread().setContextClassLoader(current);
        }
        throw new IllegalArgumentException("qualified name '" + callableClassName + "' should represent a class implementing Callable");
    }

    /** Strategy deciding which classes are defined by the realm rather than delegated. */
    public interface ReloadClassPredicate {

        /**
         * @param qualifiedName fully qualified class name about to be loaded
         * @return {@code true} if the realm should define the class itself
         */
        boolean acceptReloadOf(String qualifiedName);
    }
}