/** * Copyright (C) 2009-2013 Barchart, Inc. <http://www.barchart.com/> * * All rights reserved. Licensed under the OSI BSD License. * * http://www.opensource.org/licenses/bsd-license.php */ package com.barchart.udt; import static org.junit.Assert.*; import java.io.DataInputStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.HashMap; import java.util.Map; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TestLoadUnload { private static class JNIClassLoader extends ClassLoader { final Map<String, Class<?>> classes; File classLoadRoot; JNIClassLoader(final File classLoadRoot) { super(JNIClassLoader.class.getClassLoader()); classes = new HashMap<String, Class<?>>(); this.classLoadRoot = classLoadRoot; TestLoadUnload.isLoaderPresent = true; } @Override public String toString() { return JNIClassLoader.class.getName(); } @Override public Class<?> loadClass(final String name) throws ClassNotFoundException { return loadClass(name, false); } @Override protected Package getPackage(final String name) { if (name.equals("com.barchart.udt")) return this.getClass().getPackage(); else return super.getPackage(name); } @Override protected synchronized Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException { if (name.startsWith("com.barchart.udt.")) return findClass(name); else return super.loadClass(name, resolve); } @Override public Class<?> findClass(final String name) throws ClassNotFoundException { log.info(String.format("Attempting to find class %s", name)); if (classes.containsKey(name)) { return classes.get(name); } final String path = name.replace('.', File.separatorChar) + ".class"; byte[] b = null; try { b = loadClassData(path); } catch (final IOException e) { throw new ClassNotFoundException("Class not found at path: " + new File(name).getAbsolutePath(), e); } final Class<?> c = defineClass(name, b, 0, b.length); resolveClass(c); classes.put(name, c); return c; } private byte[] loadClassData(final String name) throws IOException { final File file = new File(classLoadRoot, name); final int size = (int) file.length(); final byte buff[] = new byte[size]; final DataInputStream in = // new DataInputStream(new FileInputStream(file)); in.readFully(buff); in.close(); return buff; } @Override protected void finalize() throws Throwable { log.info("Finalised {}", this.getClass()); super.finalize(); TestLoadUnload.isLoaderPresent = false; } } private static final Logger log = LoggerFactory .getLogger(TestLoadUnload.class); private static volatile boolean isLoaderPresent; @SuppressWarnings({ "unchecked", "rawtypes" }) private int createSocketWithClassloader() throws ClassNotFoundException, IllegalArgumentException, SecurityException, InstantiationException, IllegalAccessException, InvocationTargetException, NoSuchMethodException, InterruptedException { final String currentDir = new File(".").getAbsolutePath(); log.info(String.format("Current directory: %s", currentDir)); // HACK - ideally we should update JNIClassLoader to // use the whole java classpath final File classPath = new File("target/classes"); JNIClassLoader loader = new JNIClassLoader(classPath); Class<Enum> typeClass = (Class<Enum>) loader .findClass("com.barchart.udt.TypeUDT"); Class<?> socketClass = loader.findClass("com.barchart.udt.SocketUDT"); Object socketInstance = socketClass.getDeclaredConstructor(typeClass) .newInstance(Enum.valueOf(typeClass, "DATAGRAM")); socketClass.getMethod("cleanup").invoke(socketInstance); final int classHashCode = socketClass.hashCode(); log.info(String.format("socketClass hashCode : %s", classHashCode)); typeClass = null; socketClass = null; socketInstance = null; loader = null; /** waiting loader to finalize */ while (isLoaderPresent) { System.gc(); Thread.sleep(100); } /** waiting jvm to unlink class */ System.gc(); Thread.sleep(100); return classHashCode; } @Test(timeout = 15 * 1000) public void testLoadUnload() throws Exception { int lastClassHash = 0; { final int nextClassHash = createSocketWithClassloader(); assertTrue(nextClassHash != lastClassHash); lastClassHash = nextClassHash; } { final int nextClassHash = createSocketWithClassloader(); assertTrue(nextClassHash != lastClassHash); lastClassHash = nextClassHash; } { final int nextClassHash = createSocketWithClassloader(); assertTrue(nextClassHash != lastClassHash); lastClassHash = nextClassHash; } } public static void main(final String[] args) throws Throwable { final TestLoadUnload test = new TestLoadUnload(); test.testLoadUnload(); } }