package com.yoursway.utils.broadcaster; import static com.google.common.collect.Lists.newArrayList; import static org.objectweb.asm.Opcodes.ACC_PRIVATE; import static org.objectweb.asm.Opcodes.ACC_PUBLIC; import static org.objectweb.asm.Opcodes.ALOAD; import static org.objectweb.asm.Opcodes.ARETURN; import static org.objectweb.asm.Opcodes.ASTORE; import static org.objectweb.asm.Opcodes.CHECKCAST; import static org.objectweb.asm.Opcodes.DLOAD; import static org.objectweb.asm.Opcodes.DUP; import static org.objectweb.asm.Opcodes.FLOAD; import static org.objectweb.asm.Opcodes.GETFIELD; import static org.objectweb.asm.Opcodes.GOTO; import static org.objectweb.asm.Opcodes.IFEQ; import static org.objectweb.asm.Opcodes.ILOAD; import static org.objectweb.asm.Opcodes.INVOKEINTERFACE; import static org.objectweb.asm.Opcodes.INVOKESPECIAL; import static org.objectweb.asm.Opcodes.INVOKESTATIC; import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL; import static org.objectweb.asm.Opcodes.LLOAD; import static org.objectweb.asm.Opcodes.NEW; import static org.objectweb.asm.Opcodes.PUTFIELD; import static org.objectweb.asm.Opcodes.RETURN; import static org.objectweb.asm.Opcodes.V1_5; import java.lang.reflect.Method; import java.util.HashMap; import java.util.List; import java.util.Map; import org.objectweb.asm.ClassWriter; import org.objectweb.asm.Label; import org.objectweb.asm.MethodVisitor; import com.yoursway.utils.Listeners; import com.yoursway.utils.bugs.Bugs; public class BroadcasterFactory<Listener> { private static final String IMPL_SUFFIX = "_MagicBroadcasterImpl"; private static final String FACTORY_SUFFIX = "_MagicBroadcasterFactory"; private static final String THROWABLE_NAME = "java/lang/Throwable"; private static final String OBJECT_NAME = "java/lang/Object"; private static final String OBJECT_SIG = "L" + OBJECT_NAME + ";"; private static final String ADD_OR_REMOVE_LISTENER_SIG = "(" + OBJECT_SIG + ")V"; private static final String ITERATOR_NAME = "java/util/Iterator"; private static final String ITERATOR_SIG = "L" + ITERATOR_NAME + ";"; private static final String LISTENERS_FIELD = "listeners"; private static final String LISTENERS_SIG = signatureOf(Listeners.class); private static final String BUGS_NAME = internalNameOf(Bugs.class); private static final String FACTORY_INTERFACE_NAME = internalNameOf(InternalBroadcasterFactory.class); private static final String BROADCASTER_INTERFACE_NAME = internalNameOf(Broadcaster.class); private static final String LISTENERS_NAME = internalNameOf(Listeners.class); private static final String INIT = "<init>"; private static final String[] NO_EXCEPTIONS = new String[0]; private final InternalBroadcasterFactory factory; public BroadcasterFactory(Class<Listener> listenerInterface, ClassLoader classLoader) { if (listenerInterface == null) throw new NullPointerException("listenerInterface is null"); if (classLoader == null) throw new NullPointerException("classLoader is null"); String interfaceName = internalNameOf(listenerInterface); final String factoryName = interfaceName + FACTORY_SUFFIX; final String className = interfaceName + IMPL_SUFFIX; final byte[] factoryCode = emitFactory(factoryName, className); final byte[] implCode = emitImpl(interfaceName, className, listenerInterface.getDeclaredMethods()); // ClassReader r = new ClassReader(implCode); // TraceClassVisitor v = new TraceClassVisitor(new PrintWriter(System.out)); // r.accept(v, 0); ClassLoader loader = new ClassLoader(listenerInterface.getClassLoader()) { @Override protected Class<?> findClass(String name) throws ClassNotFoundException { if (name.equals(className)) return defineClass(className.replace('/', '.'), implCode, 0, implCode.length); if (name.equals(factoryName)) return defineClass(factoryName.replace('/', '.'), factoryCode, 0, factoryCode.length); throw new ClassNotFoundException(name); } }; loadClass(loader, className); // dunno why it's needed, but it is Class<? extends InternalBroadcasterFactory> klass = loadClass(loader, factoryName); try { this.factory = klass.newInstance(); } catch (InstantiationException e) { throw new AssertionError(e); } catch (IllegalAccessException e) { throw new AssertionError(e); } } @SuppressWarnings("unchecked") public Broadcaster<Listener> newInstance() { return (Broadcaster<Listener>) factory.create(); } @SuppressWarnings("unchecked") public static <Listener> void addBroadcasterListener(Listener broadcaster, Listener listener) { ((Broadcaster) broadcaster).addListener(listener); } @SuppressWarnings("unchecked") public static <Listener> void removeBroadcasterListener(Listener broadcaster, Listener listener) { ((Broadcaster) broadcaster).removeListener(listener); } @SuppressWarnings("unchecked") private Class<? extends InternalBroadcasterFactory> loadClass(ClassLoader loader, final String className) throws AssertionError { Class<? extends InternalBroadcasterFactory> klass; try { klass = (Class<? extends InternalBroadcasterFactory>) loader.loadClass(className); } catch (ClassNotFoundException e) { throw new AssertionError(e); } return klass; } private static Map<Class<?>, BroadcasterFactory<?>> broadcasters = new HashMap<Class<?>, BroadcasterFactory<?>>(); public synchronized static <T> Broadcaster<T> newBroadcaster(Class<T> listenerInterface) { return newBroadcasterFactory(listenerInterface).newInstance(); } /** * <code>com.yoursway.utils</code> plugin must be visible from the plugin * where the given interface is defined. If not, please use an overloaded * version which allows explicit specification of a class loader. */ public synchronized static <T> BroadcasterFactory<T> newBroadcasterFactory(Class<T> listenerInterface) { return newBroadcasterFactory(listenerInterface, listenerInterface.getClassLoader()); } @SuppressWarnings("unchecked") public synchronized static <T> BroadcasterFactory<T> newBroadcasterFactory(Class<T> listenerInterface, ClassLoader classLoader) { BroadcasterFactory<T> result = (BroadcasterFactory<T>) broadcasters.get(listenerInterface); if (result == null) { result = new BroadcasterFactory(listenerInterface, classLoader); broadcasters.put(listenerInterface, result); } return result; } private static <T> byte[] emitFactory(String factoryName, String className) { ClassWriter w = new ClassWriter(0); w.visit(V1_5, ACC_PUBLIC, factoryName, null, OBJECT_NAME, new String[] { FACTORY_INTERFACE_NAME }); emitDefaultConstructor(className, w); emitFactoryMethod(className, w); w.visitEnd(); return w.toByteArray(); } private static void emitFactoryMethod(String className, ClassWriter w) { MethodVisitor mw = w.visitMethod(ACC_PUBLIC, "create", "()" + OBJECT_SIG, null, NO_EXCEPTIONS); mw.visitTypeInsn(NEW, className); mw.visitInsn(DUP); mw.visitMethodInsn(INVOKESPECIAL, className, INIT, "()V"); mw.visitInsn(ARETURN); mw.visitMaxs(2, 1); mw.visitEnd(); } private static <T> byte[] emitImpl(String interfaceName, String className, Method[] methods) { ClassWriter w = new ClassWriter(0); w.visit(V1_5, ACC_PUBLIC, className, null, OBJECT_NAME, new String[] { interfaceName, BROADCASTER_INTERFACE_NAME }); w.visitField(ACC_PRIVATE, LISTENERS_FIELD, LISTENERS_SIG, null, null).visitEnd(); emitConstructor(className, w); emitBroadcasterMethodThunk(className, "addListener", "add", w); emitBroadcasterMethodThunk(className, "removeListener", "remove", w); emitFireMethod(className, w); for (Method method : methods) { String name = method.getName(); if (name.equals("addListener") || name.equals("removeListener")) continue; emitListenerMethodThunk(interfaceName, className, w, descriptorOf(method), name, exceptionsOf(method), argumentKindsOf(method)); } w.visitEnd(); return w.toByteArray(); } private static void emitBroadcasterMethodThunk(String className, String methodName, String targetName, ClassWriter w) { MethodVisitor mw = w.visitMethod(ACC_PUBLIC, methodName, ADD_OR_REMOVE_LISTENER_SIG, null, NO_EXCEPTIONS); mw.visitVarInsn(ALOAD, 0); mw.visitFieldInsn(GETFIELD, className, LISTENERS_FIELD, LISTENERS_SIG); mw.visitVarInsn(ALOAD, 1); mw.visitMethodInsn(INVOKEVIRTUAL, LISTENERS_NAME, targetName, ADD_OR_REMOVE_LISTENER_SIG); mw.visitInsn(RETURN); mw.visitMaxs(2, 2); mw.visitEnd(); } private static void emitFireMethod(String className, ClassWriter w) { MethodVisitor mw = w.visitMethod(ACC_PUBLIC, "fire", "()" + OBJECT_SIG, null, NO_EXCEPTIONS); mw.visitVarInsn(ALOAD, 0); mw.visitInsn(ARETURN); mw.visitMaxs(1, 1); mw.visitEnd(); } private static void emitListenerMethodThunk(String interfaceName, String className, ClassWriter w, String signature, String methodName, String[] exceptions, List<TypeFamily> argumentTypes) { MethodVisitor mw = w.visitMethod(ACC_PUBLIC, methodName, signature, null, exceptions); int argCount = argumentTypes.size(); int argSize = 0; for (TypeFamily family : argumentTypes) argSize += family.size(); int firstLocal = argSize + 1; int secondLocal = firstLocal + 1; Label startLoop = new Label(); Label endLoop = new Label(); Label tryStart = new Label(); Label catchStart = new Label(); mw.visitVarInsn(ALOAD, 0); mw.visitFieldInsn(GETFIELD, className, LISTENERS_FIELD, LISTENERS_SIG); mw.visitMethodInsn(INVOKEVIRTUAL, LISTENERS_NAME, "iterator", "()" + ITERATOR_SIG); mw.visitVarInsn(ASTORE, firstLocal); mw.visitLabel(startLoop); mw.visitVarInsn(ALOAD, firstLocal); mw.visitMethodInsn(INVOKEINTERFACE, ITERATOR_NAME, "hasNext", "()Z"); mw.visitJumpInsn(IFEQ, endLoop); mw.visitVarInsn(ALOAD, firstLocal); mw.visitMethodInsn(INVOKEINTERFACE, ITERATOR_NAME, "next", "()" + OBJECT_SIG); mw.visitTypeInsn(CHECKCAST, interfaceName); mw.visitVarInsn(ASTORE, secondLocal); mw.visitLabel(tryStart); mw.visitVarInsn(ALOAD, secondLocal); for (int arg = 1; arg <= argCount; arg += argumentTypes.get(arg - 1).size()) mw.visitVarInsn(argumentTypes.get(arg - 1).loadCommand(), arg); mw.visitMethodInsn(INVOKEINTERFACE, interfaceName, methodName, signature); mw.visitJumpInsn(GOTO, startLoop); mw.visitLabel(catchStart); mw.visitTryCatchBlock(tryStart, catchStart, catchStart, THROWABLE_NAME); mw.visitVarInsn(ALOAD, secondLocal); mw.visitMethodInsn(INVOKESTATIC, BUGS_NAME, "listenerFailed", "(Ljava/lang/Throwable;Ljava/lang/Object;)V"); mw.visitJumpInsn(GOTO, startLoop); mw.visitLabel(endLoop); mw.visitInsn(RETURN); mw.visitMaxs(Math.max(2, argSize + 1), argSize + 3); mw.visitEnd(); } private static void emitConstructor(String className, ClassWriter w) { MethodVisitor cw = w.visitMethod(ACC_PUBLIC, INIT, "()V", null, NO_EXCEPTIONS); cw.visitVarInsn(ALOAD, 0); cw.visitMethodInsn(INVOKESPECIAL, OBJECT_NAME, INIT, "()V"); cw.visitVarInsn(ALOAD, 0); cw.visitTypeInsn(NEW, LISTENERS_NAME); cw.visitInsn(DUP); cw.visitMethodInsn(INVOKESPECIAL, LISTENERS_NAME, INIT, "()V"); cw.visitFieldInsn(PUTFIELD, className, LISTENERS_FIELD, LISTENERS_SIG); cw.visitInsn(RETURN); cw.visitMaxs(3, 1); cw.visitEnd(); } private static void emitDefaultConstructor(String className, ClassWriter w) { MethodVisitor cw = w.visitMethod(ACC_PUBLIC, INIT, "()V", null, NO_EXCEPTIONS); cw.visitVarInsn(ALOAD, 0); cw.visitMethodInsn(INVOKESPECIAL, OBJECT_NAME, INIT, "()V"); cw.visitInsn(RETURN); cw.visitMaxs(1, 1); cw.visitEnd(); } enum TypeFamily { OBJECT { @Override public int loadCommand() { return ALOAD; } }, INT { @Override public int loadCommand() { return ILOAD; } }, LONG { @Override public int loadCommand() { return LLOAD; } @Override public int size() { return 2; } }, FLOAT { @Override public int loadCommand() { return FLOAD; } }, DOUBLE { @Override public int loadCommand() { return DLOAD; } @Override public int size() { return 2; } }, ; public abstract int loadCommand(); public int size() { return 1; } public static TypeFamily fromClass(Class<?> klass) { String name = klass.getName(); if (name.startsWith("[")) return OBJECT; // not implemented if (name.equals("boolean")) return INT; if (name.equals("byte")) return INT; if (name.equals("char")) return INT; if (name.equals("double")) return DOUBLE; if (name.equals("float")) return FLOAT; if (name.equals("int")) return INT; if (name.equals("long")) return LONG; if (name.equals("short")) return INT; return OBJECT; } } public static String signatureOf(Class<?> klass) { String name = klass.getName(); if (name.startsWith("[")) return name; if (name.equals("boolean")) return "Z"; if (name.equals("byte")) return "B"; if (name.equals("char")) return "C"; if (name.equals("double")) return "D"; if (name.equals("float")) return "F"; if (name.equals("int")) return "I"; if (name.equals("long")) return "J"; if (name.equals("short")) return "S"; if (name.equals("void")) return "V"; return "L" + internalNameOf(klass) + ";"; } private static String internalNameOf(Class<?> klass) { return internalNameOf(klass.getName()); } private static String internalNameOf(String name) { return name.replace('.', '/'); } private static String descriptorOf(Method method) { StringBuilder result = new StringBuilder(); result.append('('); for (Class<?> klass : method.getParameterTypes()) result.append(signatureOf(klass)); result.append(')'); result.append(signatureOf(method.getReturnType())); return result.toString(); } private static List<TypeFamily> argumentKindsOf(Method method) { List<TypeFamily> result = newArrayList(); for (Class<?> klass : method.getParameterTypes()) result.add(TypeFamily.fromClass(klass)); return result; } private static String[] exceptionsOf(Method method) { Class<?>[] types = method.getExceptionTypes(); String[] result = new String[types.length]; for (int i = 0; i < types.length; i++) result[i] = internalNameOf(types[i]); return result; } }