/* * JBoss, Home of Professional Open Source * Copyright 2012, Red Hat Middleware LLC, and individual contributors * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jboss.arquillian.warp.impl.client.separation; import java.io.Serializable; import java.lang.reflect.Array; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.Arrays; import org.jboss.arquillian.warp.impl.utils.ClassLoaderUtils; import org.jboss.arquillian.warp.impl.utils.SerializationUtils; import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.classloader.ShrinkWrapClassLoader; import org.jboss.shrinkwrap.api.spec.JavaArchive; /** * Invokes given interface on instance migrated to the context of another classloader. * * @author Lukas Fryc */ public class SeparateInvocator<T> { private ClassLoader separatedClassLoader; private Class<T> clazz; private Class<?> separatedClass; private Object instance; private InvocationHandler handler; private SeparateInvocator(Class<T> clazz, ClassLoader separatedClassLoader) { this.separatedClassLoader = separatedClassLoader; this.clazz = clazz; this.instance = instantiate(); this.handler = new SeparationHandler(); } public static <I, T extends I> I invoke(Class<T> clazz, JavaArchive... classPathArchives) { JavaArchive[] copy = new JavaArchive[classPathArchives.length + 1]; System.arraycopy(classPathArchives, 0, copy, 0, classPathArchives.length); copy[copy.length - 1] = ShrinkWrap.create(JavaArchive.class).addClass(SerializationUtils.class); ClassLoader separatedClassLoader = new ShrinkWrapClassLoader(ClassLoaderUtils.getBootstrapClassLoader(), copy); return invoke(clazz, separatedClassLoader); } @SuppressWarnings("unchecked") public static <I, T extends I> I invoke(Class<T> clazz, ClassLoader separatedClassLoader) { SeparateInvocator<T> magic = new SeparateInvocator<T>(clazz, separatedClassLoader); Class<?>[] interfaces = clazz.getInterfaces(); return (T) Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), interfaces, magic.handler); } @SuppressWarnings("unchecked") private static <R, T> R invokeStatic(ClassLoader separatedClassLoader, Class<T> clazz, Method method, Object... args) { SeparateInvocator<T> magic = new SeparateInvocator<T>(clazz, separatedClassLoader); Method adoptedMethod = magic.adoptMethod(method); Object[] adoptedArgs = magic.adaptArgs(args, Thread.currentThread().getContextClassLoader(), separatedClassLoader); try { return (R) adoptedMethod.invoke(null, adoptedArgs); } catch (Exception e) { throw new IllegalStateException("Unable to invoke static method " + method.getName() + " from class " + clazz.getName(), e); } } private class SeparationHandler implements InvocationHandler { @Override public Object invoke(Object proxy, final Method method, Object[] args) throws Throwable { final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); final Object[] adoptedArgs = adaptArgs(args, contextClassLoader, separatedClassLoader); final Method adoptedMethod = adoptMethod(method); Object result = new InvokeSeparately<Object>() { @Override public Object invoke() { try { return adoptedMethod.invoke(instance, adoptedArgs); } catch (Exception e) { throw new IllegalStateException("Unable to invoke method separately", e); } } }.run(separatedClassLoader); Object adoptedResult = adapt(result, separatedClassLoader, contextClassLoader); return adoptedResult; } } private Method adoptMethod(Method method) { try { return separatedClass.getMethod(method.getName(), adoptMethodParameterTypes(method.getParameterTypes())); } catch (Exception e) { throw new IllegalStateException("Cannot find method " + method.getName() + " with arguments " + Arrays.asList(method.getParameterTypes()) + " on class " + separatedClass.getName() + " loaded on separated class loader"); } } private Class<?>[] adoptMethodParameterTypes(Class<?>[] parameterTypes) { Class<?>[] adopted = new Class<?>[parameterTypes.length]; for (int i = 0; i < parameterTypes.length; i++) { try { adopted[i] = adoptType(parameterTypes[i]); } catch (Exception e) { throw new IllegalStateException("Cannot adopt method parameter type " + parameterTypes[i], e); } } return adopted; } private Class<?> adoptType(Class<?> type) { try { if (type.isPrimitive()) { return type; } else if (type.isArray()) { Class<?> componentType = type.getComponentType(); Class<?> adoptedComponentType = adoptType(componentType); return Array.newInstance(adoptedComponentType, 0).getClass(); } else { return loadSeparatedClassSafely(type); } } catch (Exception e) { throw new IllegalArgumentException("Unable to adopt type of " + type.getName(), e); } } private Object instantiate() { try { separatedClass = loadSeparatedClassSafely(clazz); return separatedClass.newInstance(); } catch (Exception e) { throw new IllegalStateException( "Unable to instantiate class " + clazz.getName() + " on separated classloader", e); } } private Object[] adaptArgs(Object[] args, ClassLoader from, ClassLoader to) { Object[] adapted = new Object[args.length]; for (int i = 0; i < adapted.length; i++) { adapted[i] = adapt(args[i], from, to); } return adapted; } private Object adapt(final Object object, final ClassLoader from, final ClassLoader to) { if (from == to) { return object; } if (object.getClass().getName().startsWith("java.")) { return object; } if (object instanceof Serializable) { final Method serializeMethod = getMethodSafely(SerializationUtils.class, "serializeToBytes", Serializable.class); final Method deserializeMethod = getMethodSafely(SerializationUtils.class, "deserializeFromBytes", new byte[0].getClass()); final byte[] serialized = new InvokeSeparately<byte[]>() { @Override public byte[] invoke() { // return SerializationUtils.serializeToBytes((Serializable) object); return invokeStatic(from, SerializationUtils.class, serializeMethod, object); } }.run(from); final Object deserialized = new InvokeSeparately<Object>() { @Override public Object invoke() { // return SerializationUtils.deserializeFromBytes(serialized); return invokeStatic(to, SerializationUtils.class, deserializeMethod, serialized); } }.run(to); return deserialized; } throw new IllegalStateException("Unable to adapt instance of " + object.getClass().getName()); } private abstract class InvokeSeparately<R> { public R run(ClassLoader separatedClassLoader) { final ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader(); try { Thread.currentThread().setContextClassLoader(separatedClassLoader); return invoke(); } finally { Thread.currentThread().setContextClassLoader(originalClassLoader); } } public abstract R invoke(); } private static Method getMethodSafely(Class<?> clazz, String methodName, Class<?>... parameterTypes) { try { return clazz.getMethod(methodName, parameterTypes); } catch (Exception e) { throw new IllegalStateException("Cannot obtain method " + methodName + " from class " + clazz.getName(), e); } } private Class<?> loadSeparatedClassSafely(Class<?> clazz) { try { String className = clazz.getName(); return separatedClassLoader.loadClass(className); } catch (ClassNotFoundException e) { throw new IllegalStateException("Class " + clazz.getName() + " wasn't found on separated class loader", e); } } }