/* * JBoss, Home of Professional Open Source. * Copyright 2008, Red Hat Middleware LLC, and individual contributors * as indicated by the @author tags. See the copyright.txt file in the * distribution for a full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jboss.test.cluster.defaultcfg.test; import java.lang.reflect.Method; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import junit.framework.TestCase; import org.jboss.ha.framework.server.JChannelFactory; import org.jboss.logging.Logger; import org.jgroups.Address; import org.jgroups.Channel; import org.jgroups.MembershipListener; import org.jgroups.Message; import org.jgroups.MessageListener; import org.jgroups.View; import org.jgroups.blocks.GroupRequest; import org.jgroups.blocks.MessageDispatcher; import org.jgroups.blocks.RequestHandler; /** * Tests classloader leak handling of {@link JChannelFactory}. * * @author Brian Stansberry */ public class JChannelFactoryClassLoaderLeakTestCase extends TestCase { private static final Logger log = Logger.getLogger(JChannelFactoryClassLoaderLeakTestCase.class); private static Method OBJECT_ARG = null; private static Method STRING_ARG = null; private static Method SIMPLE_MUX = null; private static Method COMPLEX_MUX = null; static { Class clazz = JChannelFactory.class; try { OBJECT_ARG = clazz.getDeclaredMethod("createChannel", new Class[] { Object.class }); STRING_ARG = clazz.getDeclaredMethod("createChannel", new Class[] { String.class }); SIMPLE_MUX = clazz.getDeclaredMethod("createMultiplexerChannel", new Class[] { String.class, String.class }); COMPLEX_MUX = clazz.getDeclaredMethod("createMultiplexerChannel", new Class[] { String.class, String.class, boolean.class, String.class }); } catch (NoSuchMethodException nsme) { log.error("Reflection failure", nsme); } } private JChannelFactory factory1; private JChannelFactory factory2; private Channel channel1; private Channel channel2; private String jgroups_bind_addr; private ClassLoader testLoader; /** * Create a new JChannelFactoryUnitTestCase. * * @param name */ public JChannelFactoryClassLoaderLeakTestCase(String name) { super(name); } protected void setUp() throws Exception { if (COMPLEX_MUX == null) throw new IllegalStateException("Reflection failed in class init; see logs"); super.setUp(); testLoader = new ClassLoader(Thread.currentThread().getContextClassLoader()){}; String jgroups_bind_addr = System.getProperty("jgroups.bind_addr"); if (jgroups_bind_addr == null) { System.setProperty("jbosstest.cluster.node0", System.getProperty("jbosstest.cluster.node0", "localhost")); } factory1 = new TestClassLoaderJChannelFactory(); factory1.setMultiplexerConfig("cluster/channelfactory/stacks.xml"); factory1.setAssignLogicalAddresses(false); factory1.setExposeChannels(false); factory1.setManageReleasedThreadClassLoader(true); factory1.create(); factory1.start(); factory2 = new TestClassLoaderJChannelFactory(); factory2.setMultiplexerConfig("cluster/channelfactory/stacks.xml"); factory2.setAssignLogicalAddresses(false); factory2.setExposeChannels(false); factory2.setManageReleasedThreadClassLoader(true); factory2.create(); factory2.start(); } protected void tearDown() throws Exception { super.tearDown(); testLoader = null; if (jgroups_bind_addr == null) System.clearProperty("jgroups.bind_addr"); if (channel1 != null && channel1.isOpen()) channel1.close(); if (channel2 != null && channel2.isOpen()) channel2.close(); if (factory1 != null) { factory1.stop(); factory1.destroy(); } if (factory2 != null) { factory2.stop(); factory2.destroy(); } } public void testClassLoaderLeakObjectShared() throws Exception { Object[] args1 = { factory1.getConfig("shared1") }; Object[] args2 = { factory2.getConfig("shared2") }; classloaderLeakTest(OBJECT_ARG, args1, args2); } public void testClassLoaderLeakObjectUnshared() throws Exception { Object[] args1 = { factory1.getConfig("unshared1") }; Object[] args2 = { factory2.getConfig("unshared2") }; classloaderLeakTest(OBJECT_ARG, args1, args2); } public void testClassLoaderLeakStringShared() throws Exception { Object[] args1 = { "shared1" }; Object[] args2 = { "shared2" }; classloaderLeakTest(STRING_ARG, args1, args2); } public void testClassLoaderLeakStringUnshared() throws Exception { Object[] args1 = { "unshared1" }; Object[] args2 = { "unshared2" }; classloaderLeakTest(STRING_ARG, args1, args2); } public void testClassLoaderLeakSimpleMuxShared() throws Exception { Object[] args1 = { "shared1", "leaktest" }; Object[] args2 = { "shared2", "leaktest" }; classloaderLeakTest(SIMPLE_MUX, args1, args2); } public void testClassLoaderLeakSimpleMuxUnshared() throws Exception { Object[] args1 = { "unshared1", "leaktest" }; Object[] args2 = { "unshared2", "leaktest" }; classloaderLeakTest(SIMPLE_MUX, args1, args2); } public void testClassLoaderLeakComplexMuxShared() throws Exception { Object[] args1 = { "shared1", "leaktest", Boolean.FALSE, null }; Object[] args2 = { "shared2", "leaktest", Boolean.FALSE, null }; classloaderLeakTest(COMPLEX_MUX, args1, args2); } public void testClassLoaderLeakComplexMuxUnshared() throws Exception { Object[] args1 = { "unshared1", "leaktest", Boolean.FALSE, null }; Object[] args2 = { "unshared2", "leaktest", Boolean.FALSE, null }; classloaderLeakTest(COMPLEX_MUX, args1, args2); } public void testClassLoaderLeakNonConcurrent() throws Exception { Object[] args1 = { "nonconcurrent1" }; Object[] args2 = { "nonconcurrent2" }; classloaderLeakTest(STRING_ARG, args1, args2); } private void classloaderLeakTest(Method factoryMeth, Object[] factory1Args, Object[] factory2Args) throws Exception { int numThreads = 8; int numLoops = 100; Semaphore semaphore = new Semaphore(numThreads); ThreadGroup runnerGroup = new ThreadGroup("TestRunners"); ClassLoader ours = Thread.currentThread().getContextClassLoader(); // The classloader we want channel threads to use ClassLoaderLeakHandler handler = new ClassLoaderLeakHandler(testLoader, semaphore, runnerGroup); MessageDispatcher[] dispatchers = new MessageDispatcher[2]; // Thread.currentThread().setContextClassLoader(testLoader); // try // { channel1 = (Channel) factoryMeth.invoke(factory1, factory1Args); dispatchers[0] = new MessageDispatcher(channel1, handler, handler, handler); channel1.connect("leaktest"); assertEquals("No classloader leak on channel1 connect", null, handler.getLeakedClassLoader()); channel2 = (Channel) factoryMeth.invoke(factory2, factory2Args); dispatchers[1] = new MessageDispatcher(channel2, handler, handler, handler); channel2.connect("leaktest"); assertEquals("No classloader leak on channel2 connect", null, handler.getLeakedClassLoader()); // } // finally // { // Thread.currentThread().setContextClassLoader(ours); // } log.info("Channels connected"); ClassLoaderLeakRunner[] runners = new ClassLoaderLeakRunner[numThreads]; for (int i = 0; i < runners.length; i++) { MessageDispatcher disp = dispatchers[i % 2]; runners[i] = new ClassLoaderLeakRunner(disp, numLoops, runnerGroup, semaphore); } semaphore.acquire(numThreads); for (int i = 0; i < runners.length; i++) { runners[i].start(); } semaphore.release(numThreads); try { assertTrue("messages received within 15 seconds", semaphore.tryAcquire(numThreads, 15, TimeUnit.SECONDS)); log.info("Messages received"); } finally { for (int i = 0; i < runners.length; i++) { runners[i].stop(); } } log.info("Sender threads stopped"); assertEquals("No classloader leak", null, handler.getLeakedClassLoader()); } private class ClassLoaderLeakRunner implements Runnable { private Thread thread; private final MessageDispatcher dispatcher; private final int numMsgs; private final ThreadGroup threadGroup; private final Semaphore semaphore; private boolean stopped; private Exception exception; ClassLoaderLeakRunner(MessageDispatcher dispatcher, int numMsgs, ThreadGroup group, Semaphore semaphore) { this.dispatcher = dispatcher; this.numMsgs = numMsgs; this.threadGroup = group; this.semaphore = semaphore; } public void run() { boolean acquired = false; ClassLoader cl = new ClassLoader(Thread.currentThread().getContextClassLoader()){}; try { semaphore.acquire(); acquired = true; log.info(Thread.currentThread().getName() + " starting"); Thread.currentThread().setContextClassLoader(cl); for (int i = 0; i < numMsgs && !stopped && !Thread.interrupted(); i++) { Message msg = new Message(null, null, String.valueOf(i)); // sending this way calls receive() dispatcher.send(msg); // sending this way calls handle() dispatcher.castMessage(null, msg, GroupRequest.GET_ALL, 0, false); } log.info(Thread.currentThread().getName() + " done"); } catch (Exception e) { this.exception = e; } finally { if (acquired) semaphore.release(); Thread.currentThread().setContextClassLoader(cl.getParent()); } } public Exception getException() { return exception; } public void start() { thread = new Thread(this.threadGroup, this); thread.setDaemon(true); thread.start(); } public void stop() { stopped = true; if (thread != null && thread.isAlive()) { try { thread.join(100); } catch (InterruptedException e) { } if (thread.isAlive()) thread.interrupt(); } } } private class ClassLoaderLeakHandler implements MembershipListener, MessageListener, RequestHandler { private final Semaphore semaphore; private final ClassLoader expected; private final ThreadGroup runnerGroup; private final Thread main; private ClassLoader leakedClassLoader; private final int numPermits; ClassLoaderLeakHandler(ClassLoader expected, Semaphore semaphore, ThreadGroup runnerGroup) { this.expected = expected; this.semaphore = semaphore; this.runnerGroup = runnerGroup; this.numPermits = this.semaphore.availablePermits(); this.main = Thread.currentThread(); } public Object handle(Message msg) { log.debug("handled(): " + msg.getObject()); checkClassLoader(true, "handle()"); return null; } public void block() { checkClassLoader(false, "block()"); } public void suspect(Address suspected_mbr) { checkClassLoader(false, "suspect()"); } public void viewAccepted(View new_view) { checkClassLoader(false, "viewAccepted()"); log.info("viewAccepted(): " + new_view); } public byte[] getState() { checkClassLoader(false, "getState()"); return new byte[1]; } public void receive(Message msg) { checkClassLoader(false, "receive()"); } public void setState(byte[] state) { checkClassLoader(false, "setState()"); } public ClassLoader getLeakedClassLoader() { return leakedClassLoader; } private void checkClassLoader(boolean fromHandle, String method) { if (leakedClassLoader == null) // ignore msgs once we found a leak { // ignore runner threads that loop all the way back up Thread current = Thread.currentThread(); if (current == main || current.getThreadGroup().equals(runnerGroup)) { return; } ClassLoader tccl = Thread.currentThread().getContextClassLoader(); if (!expected.equals(tccl)) { leakedClassLoader = tccl; semaphore.release(numPermits); log.info("ClassLoader leaked in " + method + ": " + tccl + " leaked to " + Thread.currentThread().getName()); } } } } private class TestClassLoaderJChannelFactory extends JChannelFactory { @Override protected ClassLoader getDefaultChannelThreadContextClassLoader() { return testLoader; } } }