/* * JBoss, Home of Professional Open Source * Copyright 2012, Red Hat Middleware LLC, and individual contributors * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jboss.arquillian.warp.impl.testutils; import java.lang.annotation.Annotation; import java.lang.reflect.Array; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.List; import java.util.logging.Logger; import org.jboss.arquillian.warp.impl.utils.ClassLoaderUtils; import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.classloader.ShrinkWrapClassLoader; import org.jboss.shrinkwrap.api.spec.JavaArchive; import org.junit.Test; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; public class SeparatedClassloaderRunner extends BlockJUnit4ClassRunner { private static String ARCHIVE_MESSAGE = "There must be exactly one no-arg static method which returns JavaArchive with annotation @SeparatedClassPath defined"; private static Logger log = Logger.getLogger(SeparatedClassloaderRunner.class.getName()); private static ThreadLocal<ClassLoader> initializedClassLoader = new ThreadLocal<ClassLoader>(); private ClassLoader classLoader; private ClassLoader originalClassLoader; public SeparatedClassloaderRunner(Class<?> testClass) throws InitializationError { super(getFromTestClassloader(initializeClassLoader(testClass), testClass)); this.classLoader = initializedClassLoader.get(); initializedClassLoader.set(null); } @Override protected Statement withBeforeClasses(Statement statement) { Statement original = super.withBeforeClasses(statement); Statement backupAndReplaceClassLoader = new Statement() { @Override public void evaluate() throws Throwable { if (originalClassLoader == null) { originalClassLoader = Thread.currentThread().getContextClassLoader(); } Thread.currentThread().setContextClassLoader(classLoader); } }; return new ComposedStatement(backupAndReplaceClassLoader, original); } @Override protected Statement withAfterClasses(Statement statement) { Statement original = super.withAfterClasses(statement); Statement restoreOriginalClassLoader = new Statement() { @Override public void evaluate() throws Throwable { if (originalClassLoader != null) { Thread.currentThread().setContextClassLoader(originalClassLoader); } } }; return new ComposedStatement(original, restoreOriginalClassLoader); } private static boolean checkClassPathMethodType(Method method) { if (!Modifier.isStatic(method.getModifiers())) { return false; } if (method.getParameterTypes().length != 0) { return false; } if (method.getReturnType().isAssignableFrom(JavaArchive.class)) { return true; } if (method.getReturnType().isAssignableFrom(Array.newInstance(JavaArchive.class, 0).getClass())) { return true; } return false; } static ClassLoader initializeClassLoader(Class<?> testClass) throws InitializationError { List<Method> classPath = SecurityActions.getMethodsWithAnnotation(testClass, SeparatedClassPath.class); if (classPath.isEmpty()) { throw new InitializationError(ARCHIVE_MESSAGE); } Method method = classPath.iterator().next(); if (!checkClassPathMethodType(method)) { throw new InitializationError(ARCHIVE_MESSAGE); } JavaArchive[] archives; try { Object result = classPath.get(0).invoke(null); if (result instanceof JavaArchive) { archives = new JavaArchive[] {(JavaArchive) result}; } else { archives = (JavaArchive[]) result; } } catch (Exception e) { throw new IllegalStateException("Failed to retrieve @SeparatedClassPath archive", e); } ClassLoader shrinkWrapClassLoader = getSeparatedClassLoader(archives, testClass); initializedClassLoader.set(shrinkWrapClassLoader); return initializedClassLoader.get(); } static Class<?> getFromTestClassloader(ClassLoader classLoader, Class<?> clazz) throws InitializationError { final String className = clazz.getName(); try { Class<?> loadedClazz = classLoader.loadClass(className); log.info("Loaded test class: " + className); return loadedClazz; } catch (Throwable e) { e.printStackTrace(); throw new InitializationError(e); } } private static ClassLoader getSeparatedClassLoader(JavaArchive[] archives, Class<?> testClass) throws InitializationError { try { ClassLoader bootstrapClassLoader = ClassLoaderUtils.getBootstrapClassLoader(); JavaArchive baseArchive = ShrinkWrap.create(JavaArchive.class); // JUnit baseArchive.addClasses(Test.class); // ShrinkWrap - JavaArchive baseArchive.addClasses(SecurityActions.getAncestors(JavaArchive.class)); // testClass baseArchive.addClasses(SecurityActions.getAncestors(testClass)); archives = Arrays.copyOf(archives, archives.length + 1); archives[archives.length - 1] = baseArchive; ShrinkWrapClassLoader shrinkwrapClassLoader = new ShrinkWrapClassLoader(bootstrapClassLoader, archives); return shrinkwrapClassLoader; } catch (Exception e) { throw new InitializationError(e); } } @Override protected List<FrameworkMethod> computeTestMethods() { if (classLoader == null) { classLoader = initializedClassLoader.get(); } if (classLoader == null) { throw new IllegalStateException("classLoader must not be null in this state"); } try { Class<? extends Annotation> testAnnotation = (Class<? extends Annotation>) classLoader.loadClass(Test.class .getName()); return getTestClass().getAnnotatedMethods(testAnnotation); } catch (ClassNotFoundException e) { throw new IllegalStateException(e); } } private static class ComposedStatement extends Statement { private Statement[] statements; public ComposedStatement(Statement... statements) { this.statements = statements; } @Override public void evaluate() throws Throwable { for (Statement statement : statements) { statement.evaluate(); } } } }