package com.blogspot.toomuchcoding.common.testng; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.Spy; import org.mockito.exceptions.base.MockitoException; import org.mockito.internal.util.reflection.Fields; import org.testng.IInvokedMethod; import org.testng.ITestResult; import java.lang.reflect.Field; import java.util.Collection; import java.util.HashSet; import java.util.Set; import static org.mockito.internal.util.reflection.Fields.annotatedBy; public class MockitoAfterTestNGMethod { public void applyFor(IInvokedMethod method, ITestResult testResult) { Mockito.validateMockitoUsage(); if (method.isTestMethod()) { resetMocks(testResult.getInstance()); } } private void resetMocks(Object instance) { Mockito.reset(instanceMocksOf(instance).toArray()); } @SuppressWarnings({"deprecation", "unchecked"}) private Collection<Object> instanceMocksOf(Object instance) { return Fields.allDeclaredFieldsOf(instance) .filter(annotatedBy(Mock.class, Spy.class, MockitoAnnotations.Mock.class)) .notNull() .assignedValues(); } private Set<Object> instanceMocksIn(Object instance, Class<?> clazz) { Set<Object> instanceMocks = new HashSet<Object>(); Field[] declaredFields = clazz.getDeclaredFields(); for (Field declaredField : declaredFields) { if (declaredField.isAnnotationPresent(Mock.class) || declaredField.isAnnotationPresent(Spy.class)) { declaredField.setAccessible(true); try { Object fieldValue = declaredField.get(instance); if (fieldValue != null) { instanceMocks.add(fieldValue); } } catch (IllegalAccessException e) { throw new MockitoException("Could not access field " + declaredField.getName()); } } } return instanceMocks; } }