package org.hotswap.agent.plugin.proxy.test.methods;
import static org.hotswap.agent.plugin.proxy.test.util.HotSwapTestHelper.*;
import static org.junit.Assert.*;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.junit.Before;
import org.junit.Test;
import org.springframework.cglib.proxy.Callback;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.MethodInterceptor;
import org.springframework.cglib.proxy.MethodProxy;
public class AddEnhancerMethodProxyTest {
// Version 0
public static class AImpl implements A {
@Override
public int getValue1() {
return 1;
}
}
// Version 0
public static class AImpl___0 implements A___0 {
@Override
public int getValue1() {
return 1;
}
}
// Version 1
public static class AImpl___1 implements A___1 {
@Override
public int getValue2() {
return 2;
}
}
// Version 2
public static class AImpl___2 implements A___2 {
@Override
public int getValue3() {
return 3;
}
}
// Version 0
public interface A {
public int getValue1();
}
// Version 0
public interface A___0 {
public int getValue1();
}
// Version 1
public interface A___1 {
public int getValue2();
}
// Version 2
public interface A___2 {
public int getValue3();
}
@Before
public void setUp() throws Exception {
__toVersion__Delayed(0);
}
@Test
public void addMethodToInterfaceAndImplementation() throws IllegalAccessException, IllegalArgumentException,
InvocationTargetException {
assert __version__() == 0;
final A a = new AImpl();
assertEquals(1, a.getValue1());
__toVersion__Delayed(1);
Method method = getMethod(a, "getValue2");
assertEquals(2, method.invoke(a, null));
}
public static class SerializableNoOp implements Serializable, MethodInterceptor {
private int count;
private AImpl obj = new AImpl();
@Override
public Object intercept(Object proxy, Method method, Object[] args, MethodProxy methodProxy) throws Throwable {
if (method.getName().startsWith("getValue"))
count++;
return method.invoke(obj, args);
}
public int getInvocationCount() {
return count;
}
}
@Test
public void accessNewMethodOnProxy() throws IllegalAccessException, IllegalArgumentException,
InvocationTargetException {
assert __version__() == 0;
SerializableNoOp cb = new SerializableNoOp();
final A a = createEnhancer(cb);
assertEquals(0, cb.getInvocationCount());
assertEquals(1, a.getValue1());
assertEquals(1, cb.getInvocationCount());
__toVersion__Delayed(1);
Method method = getMethod(a, "getValue2");
assertEquals("getValue2", method.getName());
assertEquals(1, cb.getInvocationCount());
assertEquals(2, method.invoke(a, null));
assertEquals(2, cb.getInvocationCount());
__toVersion__Delayed(2);
method = getMethod(a, "getValue3");
assertEquals("getValue3", method.getName());
assertEquals(2, cb.getInvocationCount());
assertEquals(3, method.invoke(a, null));
assertEquals(3, cb.getInvocationCount());
}
private A createEnhancer(Callback cb) {
Enhancer enhancer = new Enhancer();
enhancer.setSuperclass(AImpl.class);
enhancer.setCallback(cb);
final A a = (A) enhancer.create();
return a;
}
@Test
public void accessNewMethodOnProxyCreatedAfterSwap() throws IllegalAccessException, IllegalArgumentException,
InvocationTargetException, IOException {
assert __version__() == 0;
SerializableNoOp cb = new SerializableNoOp();
A a = createEnhancer(cb);
assertEquals(0, cb.getInvocationCount());
assertEquals(1, a.getValue1());
assertEquals(1, cb.getInvocationCount());
__toVersion__Delayed(1);
a = createEnhancer(cb);
Method method = getMethod(a, "getValue2");
assertEquals("getValue2", method.getName());
assertEquals(1, cb.getInvocationCount());
assertEquals(2, method.invoke(a, null));
assertEquals(2, cb.getInvocationCount());
}
private Method getMethod(Object a, String methodName) {
Method[] declaredMethods = a.getClass().getMethods();
Method m = null;
for (Method method : declaredMethods) {
if (method.getName().equals(methodName))
m = method;
}
if (m == null) {
fail(a.getClass().getSimpleName() + " does not have method " + methodName);
}
return m;
}
}