/* * Copyright (c) 2016 Mockito contributors * This program is made available under the terms of the MIT License. */ package org.mockito.internal.creation.bytebuddy; import net.bytebuddy.asm.Advice; import net.bytebuddy.implementation.bind.annotation.Argument; import net.bytebuddy.implementation.bind.annotation.This; import net.bytebuddy.implementation.bytecode.assign.Assigner; import org.mockito.exceptions.base.MockitoException; import org.mockito.internal.exceptions.stacktrace.ConditionalStackTraceFilter; import org.mockito.internal.invocation.SerializableMethod; import org.mockito.internal.util.concurrent.WeakConcurrentMap; import java.io.IOException; import java.io.ObjectInputStream; import java.io.Serializable; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.concurrent.Callable; import static org.mockito.internal.creation.bytebuddy.InlineByteBuddyMockMaker.hideRecursiveCall; public class MockMethodAdvice extends MockMethodDispatcher { final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors; private final String identifier; private final SelfCallInfo selfCallInfo = new SelfCallInfo(); public MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) { this.interceptors = interceptors; this.identifier = identifier; } @SuppressWarnings("unused") @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) private static Callable<?> enter(@Identifier String identifier, @Advice.This Object mock, @Advice.Origin Method origin, @Advice.AllArguments Object[] arguments) throws Throwable { MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock); if (dispatcher == null || !dispatcher.isMocked(mock) || !dispatcher.isOverridden(mock, origin)) { return null; } else { return dispatcher.handle(mock, origin, arguments); } } @SuppressWarnings({"unused", "UnusedAssignment"}) @Advice.OnMethodExit private static void exit(@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned, @Advice.Enter Callable<?> mocked) throws Throwable { if (mocked != null) { returned = mocked.call(); } } @Override public Callable<?> handle(Object instance, Method origin, Object[] arguments) throws Throwable { MockMethodInterceptor interceptor = interceptors.get(instance); if (interceptor == null) { return null; } InterceptedInvocation.SuperMethod superMethod; if (instance instanceof Serializable) { superMethod = new SerializableSuperMethodCall(identifier, origin, instance, arguments); } else { superMethod = new SuperMethodCall(selfCallInfo, origin, instance, arguments); } return new ReturnValueWrapper(interceptor.doIntercept(instance, origin, arguments, superMethod)); } @Override public boolean isMock(Object instance) { return interceptors.containsKey(instance); } @Override public boolean isMocked(Object instance) { return selfCallInfo.checkSuperCall(instance) && isMock(instance); } @Override public boolean isOverridden(Object instance, Method origin) { Class<?> currentType = instance.getClass(); do { try { return origin.equals(currentType.getDeclaredMethod(origin.getName(), origin.getParameterTypes())); } catch (NoSuchMethodException ignored) { currentType = currentType.getSuperclass(); } } while (currentType != null); return true; } private static class SuperMethodCall implements InterceptedInvocation.SuperMethod { private final SelfCallInfo selfCallInfo; private final Method origin; private final Object instance; private final Object[] arguments; private SuperMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) { this.selfCallInfo = selfCallInfo; this.origin = origin; this.instance = instance; this.arguments = arguments; } @Override public boolean isInvokable() { return true; } @Override public Object invoke() throws Throwable { if (!Modifier.isPublic(origin.getDeclaringClass().getModifiers() & origin.getModifiers())) { origin.setAccessible(true); } selfCallInfo.set(instance); return tryInvoke(origin, instance, arguments); } } private static class SerializableSuperMethodCall implements InterceptedInvocation.SuperMethod { private final String identifier; private final SerializableMethod origin; private final Object instance; private final Object[] arguments; private SerializableSuperMethodCall(String identifier, Method origin, Object instance, Object[] arguments) { this.origin = new SerializableMethod(origin); this.identifier = identifier; this.instance = instance; this.arguments = arguments; } @Override public boolean isInvokable() { return true; } @Override public Object invoke() throws Throwable { Method method = origin.getJavaMethod(); if (!Modifier.isPublic(method.getDeclaringClass().getModifiers() & method.getModifiers())) { method.setAccessible(true); } MockMethodDispatcher mockMethodDispatcher = MockMethodDispatcher.get(identifier, instance); if (!(mockMethodDispatcher instanceof MockMethodAdvice)) { throw new MockitoException("Unexpected dispatcher for advice-based super call"); } ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.set(instance); return tryInvoke(method, instance, arguments); } } private static Object tryInvoke(Method origin, Object instance, Object[] arguments) throws Throwable { try { return origin.invoke(instance, arguments); } catch (InvocationTargetException exception) { Throwable cause = exception.getCause(); new ConditionalStackTraceFilter().filter(hideRecursiveCall(cause, new Throwable().getStackTrace().length, origin.getDeclaringClass())); throw cause; } } private static class ReturnValueWrapper implements Callable<Object> { private final Object returned; private ReturnValueWrapper(Object returned) { this.returned = returned; } @Override public Object call() { return returned; } } private static class SelfCallInfo extends ThreadLocal<Object> { boolean checkSuperCall(Object value) { Object current = get(); if (current == value) { set(null); return false; } else { return true; } } } @Retention(RetentionPolicy.RUNTIME) @interface Identifier { } static class ForHashCode { @SuppressWarnings("unused") @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) private static boolean enter(@Identifier String id, @Advice.This Object self) { MockMethodDispatcher dispatcher = MockMethodDispatcher.get(id, self); return dispatcher != null && dispatcher.isMock(self); } @SuppressWarnings({"unused", "UnusedAssignment"}) @Advice.OnMethodExit private static void enter(@Advice.This Object self, @Advice.Return(readOnly = false) int hashCode, @Advice.Enter boolean skipped) { if (skipped) { hashCode = System.identityHashCode(self); } } } static class ForEquals { @SuppressWarnings("unused") @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) private static boolean enter(@Identifier String identifier, @Advice.This Object self) { MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self); return dispatcher != null && dispatcher.isMock(self); } @SuppressWarnings({"unused", "UnusedAssignment"}) @Advice.OnMethodExit private static void enter(@Advice.This Object self, @Advice.Argument(0) Object other, @Advice.Return(readOnly = false) boolean equals, @Advice.Enter boolean skipped) { if (skipped) { equals = self == other; } } } public static class ForReadObject { @SuppressWarnings("unused") public static void doReadObject(@Identifier String identifier, @This MockAccess thiz, @Argument(0) ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException { objectInputStream.defaultReadObject(); MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) MockMethodDispatcher.get(identifier, thiz); if (mockMethodAdvice != null) { mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor()); } } } }