/*
 * JBoss, Home of Professional Open Source
 * Copyright 2012, Red Hat Middleware LLC, and individual contributors
 * by the @authors tag. See the copyright.txt in the distribution for a
 * full listing of individual contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jboss.arquillian.warp.impl.shared.inspection;

import static org.junit.Assert.assertFalse;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.List;

import org.jboss.arquillian.warp.Inspection;
import org.jboss.arquillian.warp.impl.client.separation.SeparateInvocator;
import org.jboss.arquillian.warp.impl.client.separation.SeparatedClassLoader;
import org.jboss.arquillian.warp.impl.client.transformation.CtClassAsset;
import org.jboss.arquillian.warp.impl.client.transformation.InspectionTransformationException;
import org.jboss.arquillian.warp.impl.client.transformation.MigratedInspection;
import org.jboss.arquillian.warp.impl.client.transformation.NoSerialVersionUIDException;
import org.jboss.arquillian.warp.impl.client.transformation.TransformedInspection;
import org.jboss.arquillian.warp.impl.shared.RequestPayload;
import org.jboss.arquillian.warp.impl.utils.ClassLoaderUtils;
import org.jboss.arquillian.warp.impl.utils.SerializationUtils;
import org.jboss.arquillian.warp.impl.utils.ShrinkWrapUtils;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.classloader.ShrinkWrapClassLoader;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.jboss.shrinkwrap.impl.base.ServiceExtensionLoader;
import org.jboss.shrinkwrap.spi.MemoryMapArchive;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Exercises loading of objects produced by {@code SharingClass} across two isolated class loaders.
 *
 * <p>
 * A "client" class loader and a "server" class loader are each built from their own set of
 * ShrinkWrap archives. Every test obtains an object from {@code SharingClass} through the client
 * class loader; the {@code ...OnOnServer} variants additionally serialize it on the client side,
 * deserialize it with the server class loader and invoke its {@code server} method there.
 * </p>
 *
 * <p>
 * All cross-class-loader interaction goes through reflection, because the classes loaded by the
 * separated class loaders are distinct from the ones visible to this test class.
 * </p>
 */
public class TestInspectionLoading {

    /** Context class loader that was active before the first replacement; {@code null} until then. */
    private ClassLoader originalClassLoader = null;

    /** Class loader backed by {@link #clientArchive()}. */
    private ClassLoader clientClassLoader;

    /** Class loader backed by {@link #serverArchive()}. */
    private ClassLoader serverClassLoader;

    /**
     * Creates both separated class loaders and installs the client one as the thread context class
     * loader.
     */
    @Before
    public void setUp() {
        clientClassLoader = separatedClassLoader(clientArchive());
        serverClassLoader = separatedClassLoader(serverArchive());

        replaceClassLoader(clientClassLoader);
    }

    /**
     * Tear-down hook: despite its name, this restores the context class loader that was active
     * before the test replaced it.
     */
    @After
    public void replaceClassLoader() {
        restoreOriginalClassLoader();
    }

    /** Passes when the static-inner-class object can be obtained through the client class loader. */
    @Test
    public void testStaticInnerClassOnClient() throws Throwable {
        getStaticInnerClass();
    }

    /** Obtains the static-inner-class object on the client and runs it through {@link #testOnServer}. */
    @Test
    public void testStaticInnerClassOnOnServer() throws Throwable {
        Object inspection = getStaticInnerClass();
        testOnServer(inspection);
    }

    /** Passes when the inner-class object can be obtained through the client class loader. */
    @Test
    public void testInnerClassOnClient() throws Throwable {
        getInnerClass();
    }

    /** Obtains the inner-class object on the client and runs it through {@link #testOnServer}. */
    @Test
    public void testInnerClassOnOnServer() throws Throwable {
        Object inspection = getInnerClass();
        testOnServer(inspection);
    }

    /** Passes when the anonymous-class object can be obtained through the client class loader. */
    @Test
    public void testAnonymousClassOnClient() throws Throwable {
        getAnonymousClass();
    }

    /** Obtains the anonymous-class object on the client and runs it through {@link #testOnServer}. */
    @Test
    public void testAnonymousClassOnOnServer() throws Throwable {
        Object inspection = getAnonymousClass();
        testOnServer(inspection);
    }

    /**
     * Serializes the given object with the client class loader, deserializes it with the server
     * class loader, invokes {@code server()} on the first element returned by the deserialized
     * object's {@code getInspections()} method and finally checks that element's class.
     *
     * @param inspection the object obtained from {@code SharingClass} on the client side
     * @throws Throwable if any reflective lookup, invocation or (de)serialization step fails
     */
    private void testOnServer(Object inspection) throws Throwable {
        try {
            byte[] serialized = serialize(inspection);

            // From here on the thread behaves as if it were running on the server side.
            replaceClassLoader(serverClassLoader);

            Object deserializedPayload = deserialize(serialized);

            Method getInspectionsMethod = deserializedPayload.getClass().getMethod("getInspections");
            List<?> deserializedInspectionList = (List<?>) getInspectionsMethod.invoke(deserializedPayload);

            Object deserializedInspection = deserializedInspectionList.iterator().next();
            Class<?> deserializedClass = deserializedInspection.getClass();

            Method serverMethod = deserializedClass.getMethod("server");
            serverMethod.invoke(deserializedInspection);

            checkClass(deserializedClass);
        } finally {
            restoreOriginalClassLoader();
        }
    }

    /**
     * Asserts that the given class is not a member class.
     *
     * @param clazz the class of the deserialized inspection
     */
    private void checkClass(Class<?> clazz) {
        // check member class invocation
        assertFalse(clazz.isMemberClass());
    }

    /**
     * Installs the given class loader as the thread context class loader, remembering the loader
     * that was active before the very first replacement so it can be restored later.
     *
     * @param classLoader the class loader to install
     */
    private void replaceClassLoader(ClassLoader classLoader) {
        if (originalClassLoader == null) {
            originalClassLoader = Thread.currentThread().getContextClassLoader();
        }
        Thread.currentThread().setContextClassLoader(classLoader);
    }

    /** Re-installs the remembered context class loader, if one was recorded. */
    private void restoreOriginalClassLoader() {
        if (originalClassLoader != null) {
            Thread.currentThread().setContextClassLoader(originalClassLoader);
        }
    }

    /**
     * @return the result of {@code SharingClass.getStaticInnerClass()} invoked via the client class loader
     */
    private Object getStaticInnerClass() throws Throwable {
        return invokeOnSharingClass("getStaticInnerClass");
    }

    /**
     * @return the result of {@code SharingClass.getInnerClass()} invoked via the client class loader
     */
    private Object getInnerClass() throws Throwable {
        return invokeOnSharingClass("getInnerClass");
    }

    /**
     * @return the result of {@code SharingClass.getAnonymousClass()} invoked via the client class loader
     */
    private Object getAnonymousClass() throws Throwable {
        return invokeOnSharingClass("getAnonymousClass");
    }

    /**
     * Loads {@code SharingClass} through the client class loader, instantiates it with its no-arg
     * constructor and invokes the named public no-arg method on the new instance.
     *
     * @param methodName name of the public no-arg method to invoke
     * @return whatever the invoked method returns
     * @throws Throwable if loading, instantiation or invocation fails
     */
    private Object invokeOnSharingClass(String methodName) throws Throwable {
        Class<?> clazz = clientClassLoader.loadClass(SharingClass.class.getName());
        // Class.newInstance() is deprecated and leaks undeclared checked exceptions.
        Object instance = clazz.getDeclaredConstructor().newInstance();
        Method method = clazz.getMethod(methodName);

        // when
        return method.invoke(instance);
    }

    /**
     * Serializes the object using the {@code SerializationUtils} class loaded by the client class
     * loader.
     *
     * @param object the object to serialize
     * @return the serialized bytes
     * @throws Throwable if the reflective call fails
     */
    private byte[] serialize(Object object) throws Throwable {
        Class<?> serializationUtilsClass = serializationUtils(clientClassLoader);
        Method serializeToBytes = serializationUtilsClass.getMethod("serializeToBytes", Serializable.class);
        return (byte[]) serializeToBytes.invoke(null, object);
    }

    /**
     * Deserializes the bytes using the {@code SerializationUtils} class loaded by the server class
     * loader.
     *
     * @param bytes the serialized form produced by {@link #serialize(Object)}
     * @return the deserialized object
     * @throws Throwable if the reflective call fails
     */
    private Object deserialize(byte[] bytes) throws Throwable {
        Class<?> serializationUtilsClass = serializationUtils(serverClassLoader);
        Method deserializeFromBytes = serializationUtilsClass.getMethod("deserializeFromBytes", byte[].class);
        // Cast to Object so the byte[] is passed as one argument instead of being spread as varargs.
        return deserializeFromBytes.invoke(null, (Object) bytes);
    }

    /**
     * Loads {@code SerializationUtils} through the given class loader.
     *
     * @param classLoader the class loader to load the utility class with
     * @return the {@code SerializationUtils} class as seen by that class loader
     * @throws Throwable if the class cannot be loaded
     */
    private Class<?> serializationUtils(ClassLoader classLoader) throws Throwable {
        return classLoader.loadClass(SerializationUtils.class.getName());
    }

    /**
     * Builds the archives that back the client class loader: the client-side test classes and Warp
     * classes listed below, plus the archives containing Javassist and the ShrinkWrap SPI, API and
     * implementation classes.
     *
     * @return the archives for the client class loader
     */
    private static JavaArchive[] clientArchive() {
        JavaArchive archive = ShrinkWrap
            .create(JavaArchive.class)
            .addClasses(ClientInterface.class, ClientImplementation.class)
            .addClasses(ServerInterface.class)
            .addClasses(SharingClass.class, Inspection.class, RequestPayload.class)
            .addClasses(TransformedInspection.class, MigratedInspection.class,
                InspectionTransformationException.class, NoSerialVersionUIDException.class)
            .addClasses(SerializationUtils.class, ShrinkWrapUtils.class, ClassLoaderUtils.class)
            .addClasses(SeparateInvocator.class, CtClassAsset.class, SeparatedClassLoader.class);

        JavaArchive javassistArchive = ShrinkWrapUtils.getJavaArchiveFromClass(javassist.CtClass.class);
        JavaArchive shrinkWrapSpi = ShrinkWrapUtils.getJavaArchiveFromClass(MemoryMapArchive.class);
        JavaArchive shrinkWrapApi = ShrinkWrapUtils.getJavaArchiveFromClass(JavaArchive.class);
        JavaArchive shrinkWrapImpl = ShrinkWrapUtils.getJavaArchiveFromClass(ServiceExtensionLoader.class);

        return new JavaArchive[] {archive, javassistArchive, shrinkWrapSpi, shrinkWrapApi, shrinkWrapImpl};
    }

    /**
     * Builds the single archive that backs the server class loader. It does not contain
     * {@code SharingClass} or {@code ClientImplementation}.
     *
     * @return the archives for the server class loader
     */
    private static JavaArchive[] serverArchive() {
        JavaArchive archive = ShrinkWrap
            .create(JavaArchive.class)
            .addClasses(ClientInterface.class)
            .addClasses(ServerInterface.class, ServerImplemenation.class)
            .addClasses(Inspection.class, RequestPayload.class)
            .addClasses(SerializationUtils.class);

        return new JavaArchive[] {archive};
    }

    /**
     * Creates a class loader that serves classes from the given archives and uses the loader
     * returned by {@code ClassLoaderUtils.getBootstrapClassLoader()} as its parent.
     *
     * @param archive the archives to expose through the new class loader
     * @return the new class loader
     */
    private ClassLoader separatedClassLoader(JavaArchive... archive) {
        return new ShrinkWrapClassLoader(ClassLoaderUtils.getBootstrapClassLoader(), archive);
    }
}