package org.hotswap.agent.plugin.proxy.test.methods;

import static org.hotswap.agent.plugin.proxy.test.util.HotSwapTestHelper.*;
import static org.junit.Assert.*;

import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

import org.junit.Before;
import org.junit.Test;
import org.springframework.cglib.proxy.Callback;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.MethodInterceptor;
import org.springframework.cglib.proxy.MethodProxy;

/**
 * Checks that methods added by a hot swap become reachable through CGLIB {@link Enhancer} proxies,
 * both on proxies created before the swap and on proxies created after it.
 *
 * <p>The {@code ___N} nested types are the version-N definitions of {@link AImpl} and {@link A}
 * (see the {@code // Version N} markers). The tests switch between them with
 * {@code __toVersion__Delayed(N)} from {@code HotSwapTestHelper}.
 */
public class AddEnhancerMethodProxyTest {

    // Version 0
    public static class AImpl implements A {
        @Override
        public int getValue1() {
            return 1;
        }
    }

    // Version 0
    public static class AImpl___0 implements A___0 {
        @Override
        public int getValue1() {
            return 1;
        }
    }

    // Version 1
    public static class AImpl___1 implements A___1 {
        @Override
        public int getValue2() {
            return 2;
        }
    }

    // Version 2
    public static class AImpl___2 implements A___2 {
        @Override
        public int getValue3() {
            return 3;
        }
    }

    // Version 0
    public interface A {
        public int getValue1();
    }

    // Version 0
    public interface A___0 {
        public int getValue1();
    }

    // Version 1
    public interface A___1 {
        public int getValue2();
    }

    // Version 2
    public interface A___2 {
        public int getValue3();
    }

    /** Resets the hot-swapped classes of this test to version 0 before every test. */
    @Before
    public void setUp() throws Exception {
        __toVersion__Delayed(0);
    }

    /**
     * Fails the test unless {@code __version__()} reports {@code expected}.
     *
     * <p>Replaces a Java {@code assert}, which is silently skipped when assertions are disabled
     * (no {@code -ea}). This helper is deliberately a method of this class, so the caller seen by
     * {@code __version__()} is still this test class.
     */
    private static void assertClassVersion(int expected) {
        assertTrue("expected hotswap version " + expected, __version__() == expected);
    }

    /** A method added to both interface and implementation is callable reflectively after the swap. */
    @Test
    public void addMethodToInterfaceAndImplementation()
            throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        assertClassVersion(0);
        final A a = new AImpl();

        assertEquals(1, a.getValue1());

        __toVersion__Delayed(1);

        Method method = getMethod(a, "getValue2");
        assertEquals(2, method.invoke(a));
    }

    /**
     * Interceptor that counts calls to methods whose name starts with {@code "getValue"}.
     * It forwards every call reflectively to its own private {@link AImpl} instance, not to the
     * proxy's superclass implementation.
     *
     * <p>NOTE(review): despite its name this is not a no-op. Although it implements
     * {@link Serializable}, its {@code obj} field is an {@link AImpl}, which is not serializable.
     * Actually serializing an instance would therefore throw {@code NotSerializableException};
     * confirm whether serialization is ever exercised.
     */
    public static class SerializableNoOp implements Serializable, MethodInterceptor {
        private int count;
        private AImpl obj = new AImpl();

        @Override
        public Object intercept(Object proxy, Method method, Object[] args, MethodProxy methodProxy)
                throws Throwable {
            if (method.getName().startsWith("getValue")) {
                count++;
            }
            return method.invoke(obj, args);
        }

        /** Returns how many {@code getValue*} calls this interceptor has seen so far. */
        public int getInvocationCount() {
            return count;
        }
    }

    /**
     * Methods added in versions 1 and 2 are callable through a proxy created at version 0.
     * Each new method is routed through the interceptor exactly once per call.
     */
    @Test
    public void accessNewMethodOnProxy()
            throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        assertClassVersion(0);
        SerializableNoOp cb = new SerializableNoOp();
        final A a = createEnhancer(cb);

        assertEquals(0, cb.getInvocationCount());
        assertEquals(1, a.getValue1());
        assertEquals(1, cb.getInvocationCount());

        __toVersion__Delayed(1);
        Method method = getMethod(a, "getValue2");
        assertEquals("getValue2", method.getName());
        // Looking the method up reflectively must not count as an intercepted call.
        assertEquals(1, cb.getInvocationCount());
        assertEquals(2, method.invoke(a));
        assertEquals(2, cb.getInvocationCount());

        __toVersion__Delayed(2);
        method = getMethod(a, "getValue3");
        assertEquals("getValue3", method.getName());
        assertEquals(2, cb.getInvocationCount());
        assertEquals(3, method.invoke(a));
        assertEquals(3, cb.getInvocationCount());
    }

    /**
     * Creates a CGLIB proxy that subclasses {@link AImpl} and has {@code cb} as its single callback.
     *
     * @param cb callback that receives the proxy's method calls
     * @return the new proxy, typed as {@link A}
     */
    private A createEnhancer(Callback cb) {
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(AImpl.class);
        enhancer.setCallback(cb);
        return (A) enhancer.create();
    }

    /** A proxy created after the swap, reusing the same interceptor, exposes the new method. */
    @Test
    public void accessNewMethodOnProxyCreatedAfterSwap()
            throws IllegalAccessException, IllegalArgumentException, InvocationTargetException,
                    IOException {
        assertClassVersion(0);
        SerializableNoOp cb = new SerializableNoOp();
        A a = createEnhancer(cb);

        assertEquals(0, cb.getInvocationCount());
        assertEquals(1, a.getValue1());
        assertEquals(1, cb.getInvocationCount());

        __toVersion__Delayed(1);

        a = createEnhancer(cb);
        Method method = getMethod(a, "getValue2");
        assertEquals("getValue2", method.getName());
        assertEquals(1, cb.getInvocationCount());
        assertEquals(2, method.invoke(a));
        assertEquals(2, cb.getInvocationCount());
    }

    /**
     * Looks up a public method by name on the runtime class of {@code target}.
     *
     * @param target object whose class is searched (including inherited public methods)
     * @param methodName simple method name to find
     * @return a public method of that name
     * @throws AssertionError if the class has no public method with that name
     */
    private Method getMethod(Object target, String methodName) {
        Class<?> type = target.getClass();
        for (Method method : type.getMethods()) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        // Same error JUnit's fail(String) raises, without the unreachable fall-through return.
        throw new AssertionError(type.getSimpleName() + " does not have method " + methodName);
    }
}