/* * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.boot.test.mock.mockito; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.util.LinkedHashSet; import java.util.Set; import org.mockito.Captor; import org.mockito.MockitoAnnotations; import org.springframework.context.ApplicationContext; import org.springframework.test.context.TestContext; import org.springframework.test.context.TestExecutionListener; import org.springframework.test.context.support.AbstractTestExecutionListener; import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils.FieldCallback; /** * {@link TestExecutionListener} to trigger {@link MockitoAnnotations#initMocks(Object)} * when {@link MockBean @MockBean} annotations are used. Primarily to allow {@link Captor} * annotations. 
*
 * @author Phillip Webb
 * @since 1.4.2
 */
public class MockitoTestExecutionListener extends AbstractTestExecutionListener {

	/**
	 * Prepare the test instance in two steps: first run
	 * {@link MockitoAnnotations#initMocks(Object)} on it when a field of the test class
	 * carries a Mockito annotation, then inject the fields reported by the
	 * {@code DefinitionsParser}.
	 * @param testContext the context of the test being prepared
	 * @throws Exception as permitted by the overridden listener method
	 */
	@Override
	public void prepareTestInstance(TestContext testContext) throws Exception {
		initMocksIfAnnotated(testContext);
		injectFields(testContext);
	}

	/**
	 * Scan the fields of the test class and, only when at least one Mockito annotation
	 * was collected, hand the test instance to Mockito for annotation processing.
	 * @param testContext the context of the test being prepared
	 */
	private void initMocksIfAnnotated(TestContext testContext) {
		MockitoAnnotationCollection found = new MockitoAnnotationCollection();
		ReflectionUtils.doWithFields(testContext.getTestClass(), found);
		if (!found.hasAnnotations()) {
			return;
		}
		MockitoAnnotations.initMocks(testContext.getTestInstance());
	}

	/**
	 * Parse the test class for definitions and, if there are any, ask the
	 * {@code MockitoPostProcessor} bean from the application context to inject each
	 * definition's field on the test instance. Definitions for which the parser
	 * returns no field are skipped.
	 * @param testContext the context of the test being prepared
	 */
	private void injectFields(TestContext testContext) {
		DefinitionsParser parsed = new DefinitionsParser();
		parsed.parse(testContext.getTestClass());
		if (parsed.getDefinitions().isEmpty()) {
			// Nothing to inject, so the application context is never requested.
			return;
		}
		ApplicationContext context = testContext.getApplicationContext();
		MockitoPostProcessor injector = context.getBean(MockitoPostProcessor.class);
		for (Definition candidate : parsed.getDefinitions()) {
			Field target = parsed.getField(candidate);
			if (target == null) {
				continue;
			}
			injector.inject(target, testContext.getTestInstance(), candidate);
		}
	}

	/**
	 * {@link FieldCallback} that gathers every annotation declared directly on the
	 * visited fields whose annotation type name starts with {@code org.mockito}.
	 */
	private static class MockitoAnnotationCollection implements FieldCallback {

		private final Set<Annotation> annotations = new LinkedHashSet<>();

		/**
		 * Record the Mockito annotations declared on the given field.
		 * @param field the field being visited
		 */
		@Override
		public void doWith(Field field)
				throws IllegalArgumentException, IllegalAccessException {
			for (Annotation declared : field.getDeclaredAnnotations()) {
				if (isMockitoAnnotation(declared)) {
					this.annotations.add(declared);
				}
			}
		}

		/**
		 * Decide whether an annotation counts as a Mockito one, based purely on the
		 * fully qualified name of its annotation type.
		 * @param annotation the annotation to test
		 * @return {@code true} if the type name starts with {@code org.mockito}
		 */
		private static boolean isMockitoAnnotation(Annotation annotation) {
			String typeName = annotation.annotationType().getName();
			return typeName.startsWith("org.mockito");
		}

		/**
		 * Report whether any Mockito annotation has been collected so far.
		 * @return {@code true} if at least one annotation was recorded
		 */
		public boolean hasAnnotations() {
			return !this.annotations.isEmpty();
		}

	}

}