/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.test.mock.mockito;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.aop.scope.ScopedProxyUtils;
import org.springframework.aop.support.AopUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.PropertyValues;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessorAdapter;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.ConfigurationClassPostProcessor;
import org.springframework.core.Conventions;
import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered;
import org.springframework.core.ResolvableType;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ReflectionUtils.FieldCallback;
import org.springframework.util.StringUtils;
/**
* A {@link BeanFactoryPostProcessor} used to register and inject
* {@link MockBean @MockBeans} with the {@link ApplicationContext}. An initial set of
* definitions can be passed to the processor with additional definitions being
* automatically created from {@code @Configuration} classes that use
* {@link MockBean @MockBean}.
*
* @author Phillip Webb
* @author Andy Wilkinson
* @author Stephane Nicoll
* @since 1.4.0
*/
public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAdapter
implements BeanClassLoaderAware, BeanFactoryAware, BeanFactoryPostProcessor,
Ordered {
/**
 * Bean definition attribute compared against the resolved type name when looking
 * for {@link FactoryBean} definitions that produce a given type.
 */
private static final String FACTORY_BEAN_OBJECT_TYPE = "factoryBeanObjectType";

/** Bean name under which this post processor is registered. */
private static final String BEAN_NAME = MockitoPostProcessor.class.getName();

/**
 * Bean definition attribute whose presence marks a definition as one of the
 * configuration classes to be parsed for mock/spy definitions.
 */
private static final String CONFIGURATION_CLASS_ATTRIBUTE = Conventions
		.getQualifiedAttributeName(ConfigurationClassPostProcessor.class,
				"configurationClass");

/** Initial definitions supplied at construction time. */
private final Set<Definition> definitions;

/** Class loader used to resolve configuration class names. */
private ClassLoader classLoader;

/** Bean factory used to look up the beans that get injected into fields. */
private BeanFactory beanFactory;

/** Generates names for mocks and spies that do not replace an existing bean. */
private final BeanNameGenerator beanNameGenerator = new DefaultBeanNameGenerator();

/** Collects every mock created by this processor; exposed as a singleton bean. */
private final MockitoBeans mockitoBeans = new MockitoBeans();

/** Bean name chosen for each registered definition. */
private final Map<Definition, String> beanNameRegistry = new HashMap<>();

/** Definition and bean name registered for each annotated field. */
private final Map<Field, RegisteredField> fieldRegistry = new HashMap<>();

/** Spy definitions keyed by the name of the bean they wrap. */
private final Map<String, SpyDefinition> spies = new HashMap<>();
/**
 * Create a new {@link MockitoPostProcessor} instance with the given initial
 * definitions. The set is stored as-is (no copy is made) and is later passed to
 * the definitions parser when the bean factory is post-processed.
 * @param definitions the initial definitions
 */
public MockitoPostProcessor(Set<Definition> definitions) {
this.definitions = definitions;
}
/**
 * Store the class loader that is later used to resolve the class names of
 * configuration bean definitions.
 * @param classLoader the bean class loader
 */
@Override
public void setBeanClassLoader(ClassLoader classLoader) {
this.classLoader = classLoader;
}
/**
 * Store the bean factory used when injecting mocks and spies into fields.
 * @param beanFactory the owning bean factory; must be a
 * {@link ConfigurableListableBeanFactory}, otherwise the assertion fails
 * @throws BeansException declared by the callback contract
 */
@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
Assert.isInstanceOf(ConfigurableListableBeanFactory.class, beanFactory,
"Mock beans can only be used with a ConfigurableListableBeanFactory");
this.beanFactory = beanFactory;
}
/**
 * Entry point of the {@link BeanFactoryPostProcessor} contract. Verifies that the
 * factory can also act as a {@link BeanDefinitionRegistry} and then delegates to
 * the registry-aware overload.
 * @param beanFactory the bean factory being post-processed
 * @throws BeansException declared by the callback contract
 */
@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
		throws BeansException {
	Assert.isInstanceOf(BeanDefinitionRegistry.class, beanFactory,
			"@MockBean can only be used on bean factories that "
					+ "implement BeanDefinitionRegistry");
	BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
	postProcessBeanFactory(beanFactory, registry);
}
/**
 * Expose the {@link MockitoBeans} holder as a singleton, collect definitions from
 * the initial set and from every configuration class, and register each of them.
 * @param beanFactory the bean factory being post-processed
 * @param registry the same factory viewed as a bean definition registry
 */
private void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory,
		BeanDefinitionRegistry registry) {
	beanFactory.registerSingleton(MockitoBeans.class.getName(), this.mockitoBeans);
	DefinitionsParser parser = new DefinitionsParser(this.definitions);
	Set<Class<?>> configurationClasses = getConfigurationClasses(beanFactory);
	for (Class<?> configurationClass : configurationClasses) {
		parser.parse(configurationClass);
	}
	for (Definition parsed : parser.getDefinitions()) {
		// The parser reports which field (if any) declared the definition.
		register(beanFactory, registry, parsed, parser.getField(parsed));
	}
}
/**
 * Resolve the classes of all configuration bean definitions found in the factory.
 * @param beanFactory the bean factory to inspect
 * @return the resolved configuration classes, in bean definition order
 */
private Set<Class<?>> getConfigurationClasses(
		ConfigurableListableBeanFactory beanFactory) {
	Map<String, BeanDefinition> configurationBeans = getConfigurationBeanDefinitions(
			beanFactory);
	Set<Class<?>> resolved = new LinkedHashSet<>();
	for (BeanDefinition configurationBean : configurationBeans.values()) {
		String className = configurationBean.getBeanClassName();
		resolved.add(ClassUtils.resolveClassName(className, this.classLoader));
	}
	return resolved;
}
/**
 * Find every bean definition that carries the configuration class attribute.
 * @param beanFactory the bean factory to inspect
 * @return matching definitions keyed by bean name, in bean definition order
 */
private Map<String, BeanDefinition> getConfigurationBeanDefinitions(
		ConfigurableListableBeanFactory beanFactory) {
	Map<String, BeanDefinition> matches = new LinkedHashMap<>();
	for (String name : beanFactory.getBeanDefinitionNames()) {
		BeanDefinition candidate = beanFactory.getBeanDefinition(name);
		if (candidate.getAttribute(CONFIGURATION_CLASS_ATTRIBUTE) == null) {
			continue;
		}
		matches.put(name, candidate);
	}
	return matches;
}
/**
 * Dispatch a definition to the mock or spy registration path according to its
 * concrete type. Definitions of any other type are ignored.
 * @param beanFactory the bean factory being post-processed
 * @param registry the bean definition registry
 * @param definition the definition to register
 * @param field the field that declared the definition, or {@code null}
 */
private void register(ConfigurableListableBeanFactory beanFactory,
		BeanDefinitionRegistry registry, Definition definition, Field field) {
	if (definition instanceof MockDefinition) {
		MockDefinition mockDefinition = (MockDefinition) definition;
		registerMock(beanFactory, registry, mockDefinition, field);
		return;
	}
	if (definition instanceof SpyDefinition) {
		SpyDefinition spyDefinition = (SpyDefinition) definition;
		registerSpy(beanFactory, registry, spyDefinition, field);
	}
}
/**
 * Register a mock for the given definition: a bean definition is registered under
 * the transformed bean name (replacing any definition already present there) and
 * the mock instance itself is registered as a singleton under the same name.
 * @param beanFactory the bean factory that receives the mock singleton
 * @param registry the registry that receives the mock's bean definition
 * @param definition the mock definition
 * @param field the field that declared the definition, or {@code null}
 */
private void registerMock(ConfigurableListableBeanFactory beanFactory,
BeanDefinitionRegistry registry, MockDefinition definition, Field field) {
RootBeanDefinition beanDefinition = createBeanDefinition(definition);
String beanName = getBeanName(beanFactory, registry, definition, beanDefinition);
// NOTE(review): transformedBeanName presumably strips a FactoryBean dereference
// prefix so that the factory's own definition is the one replaced - confirm.
String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName);
// Index 1 is the 'name' argument of the createMock(MockDefinition, String)
// factory method; index 0 was set by createBeanDefinition.
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(1,
beanName);
// Replace rather than add when a definition already exists under that name.
if (registry.containsBeanDefinition(transformedBeanName)) {
registry.removeBeanDefinition(transformedBeanName);
}
registry.registerBeanDefinition(transformedBeanName, beanDefinition);
// The mock is created immediately and registered as the singleton instance.
Object mock = createMock(definition, beanName);
beanFactory.registerSingleton(transformedBeanName, mock);
this.mockitoBeans.add(mock);
// Bookkeeping is keyed by the untransformed bean name.
this.beanNameRegistry.put(definition, beanName);
if (field != null) {
this.fieldRegistry.put(field, new RegisteredField(definition, beanName));
}
}
/**
 * Build the bean definition that describes a mock. The definition points back at
 * this post processor's {@code createMock} factory method and carries the mock
 * definition as the first factory method argument.
 * @param mockDefinition the mock definition
 * @return a new bean definition for the mock
 */
private RootBeanDefinition createBeanDefinition(MockDefinition mockDefinition) {
	ResolvableType typeToMock = mockDefinition.getTypeToMock();
	RootBeanDefinition result = new RootBeanDefinition(typeToMock.resolve());
	result.setTargetType(typeToMock);
	result.setFactoryBeanName(BEAN_NAME);
	result.setFactoryMethodName("createMock");
	ConstructorArgumentValues arguments = result.getConstructorArgumentValues();
	arguments.addIndexedArgumentValue(0, mockDefinition);
	QualifierDefinition qualifier = mockDefinition.getQualifier();
	if (qualifier != null) {
		qualifier.applyTo(result);
	}
	return result;
}
/**
 * Factory method used by defined beans to actually create the mock. It is
 * referenced by name from the bean definitions built in
 * {@code createBeanDefinition}. The mock is created with the bean name followed
 * by {@code " bean"} as its name.
 * @param mockDefinition the mock definition
 * @param name the bean name
 * @return the mock instance
 */
protected final Object createMock(MockDefinition mockDefinition, String name) {
return mockDefinition.createMock(name + " bean");
}
/**
 * Decide the bean name a mock is registered under. An explicit name on the
 * definition wins; otherwise the single matching existing bean is replaced, or a
 * fresh name is generated when nothing matches.
 * @param beanFactory the bean factory to search for existing candidates
 * @param registry the registry used when generating a new name
 * @param mockDefinition the mock definition
 * @param beanDefinition the bean definition created for the mock
 * @return the bean name to use
 * @throws IllegalStateException if more than one existing bean matches
 */
private String getBeanName(ConfigurableListableBeanFactory beanFactory,
		BeanDefinitionRegistry registry, MockDefinition mockDefinition,
		RootBeanDefinition beanDefinition) {
	String explicitName = mockDefinition.getName();
	if (StringUtils.hasLength(explicitName)) {
		return explicitName;
	}
	Set<String> candidates = findCandidateBeans(beanFactory, mockDefinition);
	int candidateCount = candidates.size();
	if (candidateCount > 1) {
		throw new IllegalStateException(
				"Unable to register mock bean " + mockDefinition.getTypeToMock()
						+ " expected a single matching bean to replace but found "
						+ candidates);
	}
	if (candidateCount == 1) {
		return candidates.iterator().next();
	}
	return this.beanNameGenerator.generateBeanName(beanDefinition, registry);
}
/**
 * Register a spy definition. When beans of the spied type already exist one of
 * them is selected to be spied on; otherwise a new bean of that type is defined.
 * @param beanFactory the bean factory to search for existing beans
 * @param registry the bean definition registry
 * @param definition the spy definition
 * @param field the field that declared the definition, or {@code null}
 */
private void registerSpy(ConfigurableListableBeanFactory beanFactory,
		BeanDefinitionRegistry registry, SpyDefinition definition, Field field) {
	String[] existing = getExistingBeans(beanFactory, definition.getTypeToSpy());
	if (!ObjectUtils.isEmpty(existing)) {
		registerSpies(registry, definition, field, existing);
		return;
	}
	createSpy(registry, definition, field);
}
/**
 * Find the existing beans a mock could replace: beans of the mocked type that
 * also satisfy the definition's qualifier, when it has one.
 * @param beanFactory the bean factory to search
 * @param mockDefinition the mock definition
 * @return the matching bean names in sorted order
 */
private Set<String> findCandidateBeans(ConfigurableListableBeanFactory beanFactory,
		MockDefinition mockDefinition) {
	QualifierDefinition qualifier = mockDefinition.getQualifier();
	Set<String> matches = new TreeSet<>();
	String[] existing = getExistingBeans(beanFactory,
			mockDefinition.getTypeToMock());
	for (String name : existing) {
		boolean accepted = (qualifier == null)
				|| qualifier.matches(beanFactory, name);
		if (accepted) {
			matches.add(name);
		}
	}
	return matches;
}
/**
 * Collect the names of existing beans of the given type. Besides direct type
 * matches this includes {@link FactoryBean} definitions whose
 * {@code factoryBeanObjectType} attribute equals the resolved type name. Names
 * recognised as scoped targets are dropped from the result.
 * @param beanFactory the bean factory to search
 * @param type the type to look for
 * @return the matching bean names, direct matches first
 */
private String[] getExistingBeans(ConfigurableListableBeanFactory beanFactory,
		ResolvableType type) {
	String[] directMatches = beanFactory.getBeanNamesForType(type);
	Set<String> names = new LinkedHashSet<>(Arrays.asList(directMatches));
	String expectedObjectType = type.resolve(Object.class).getName();
	for (String factoryBeanName : beanFactory
			.getBeanNamesForType(FactoryBean.class)) {
		String name = BeanFactoryUtils.transformedBeanName(factoryBeanName);
		Object declaredObjectType = beanFactory.getBeanDefinition(name)
				.getAttribute(FACTORY_BEAN_OBJECT_TYPE);
		if (expectedObjectType.equals(declaredObjectType)) {
			names.add(name);
		}
	}
	Iterator<String> remaining = names.iterator();
	while (remaining.hasNext()) {
		if (isScopedTarget(remaining.next())) {
			remaining.remove();
		}
	}
	return names.toArray(new String[names.size()]);
}
/**
 * Check whether the given bean name denotes a scoped target, as decided by
 * {@link ScopedProxyUtils}. Any failure while performing the check is treated as
 * "not a scoped target".
 * @param beanName the bean name to check
 * @return {@code true} if the name is a scoped target name
 */
private boolean isScopedTarget(String beanName) {
try {
return ScopedProxyUtils.isScopedTarget(beanName);
}
catch (Throwable ex) {
// Deliberate best-effort. NOTE(review): presumably guards against the
// scoped proxy support classes being unavailable at runtime - confirm.
return false;
}
}
/**
 * Define a new bean of the spied type under a generated name and record it as a
 * spy. Used when no bean of that type exists yet.
 * @param registry the registry that receives the new bean definition
 * @param definition the spy definition
 * @param field the field that declared the definition, or {@code null}
 */
private void createSpy(BeanDefinitionRegistry registry, SpyDefinition definition,
		Field field) {
	Class<?> spiedType = definition.getTypeToSpy().resolve();
	RootBeanDefinition spiedBeanDefinition = new RootBeanDefinition(spiedType);
	String generatedName = this.beanNameGenerator
			.generateBeanName(spiedBeanDefinition, registry);
	registry.registerBeanDefinition(generatedName, spiedBeanDefinition);
	registerSpy(definition, field, generatedName);
}
/**
 * Select which of the existing beans should be spied on and record it. Any
 * runtime failure is rethrown as an {@link IllegalStateException} naming the
 * spied type.
 * @param registry the bean definition registry
 * @param definition the spy definition
 * @param field the field that declared the definition, or {@code null}
 * @param existingBeans names of the existing beans of the spied type
 */
private void registerSpies(BeanDefinitionRegistry registry, SpyDefinition definition,
		Field field, String[] existingBeans) {
	try {
		String selectedName = determineBeanName(existingBeans, definition,
				registry);
		registerSpy(definition, field, selectedName);
	}
	catch (RuntimeException ex) {
		throw new IllegalStateException(
				"Unable to register spy bean " + definition.getTypeToSpy(), ex);
	}
}
/**
 * Choose the bean to spy on: the explicitly named bean if the definition has a
 * name, else the only existing bean, else the primary one among several.
 * @param existingBeans names of the existing beans of the spied type
 * @param definition the spy definition
 * @param registry the bean definition registry
 * @return the chosen bean name, or {@code null} if several candidates exist and
 * none of them is primary
 */
private String determineBeanName(String[] existingBeans, SpyDefinition definition,
		BeanDefinitionRegistry registry) {
	String explicitName = definition.getName();
	if (StringUtils.hasText(explicitName)) {
		return explicitName;
	}
	return (existingBeans.length == 1) ? existingBeans[0]
			: determinePrimaryCandidate(registry, existingBeans,
					definition.getTypeToSpy());
}
/**
 * Find the single candidate whose bean definition is marked primary.
 * @param registry the registry holding the candidate definitions
 * @param candidateBeanNames the candidate bean names
 * @param type the type being looked up, used for error reporting
 * @return the primary candidate's name, or {@code null} if none is primary
 * @throws NoUniqueBeanDefinitionException if more than one candidate is primary
 */
private String determinePrimaryCandidate(BeanDefinitionRegistry registry,
		String[] candidateBeanNames, ResolvableType type) {
	String primary = null;
	for (String candidate : candidateBeanNames) {
		if (!registry.getBeanDefinition(candidate).isPrimary()) {
			continue;
		}
		if (primary != null) {
			throw new NoUniqueBeanDefinitionException(type.resolve(),
					candidateBeanNames.length,
					"more than one 'primary' bean found among candidates: "
							+ Arrays.asList(candidateBeanNames));
		}
		primary = candidate;
	}
	return primary;
}
/**
 * Record that the bean with the given name is to be wrapped in a spy, and
 * remember the name for the definition and (when present) its declaring field.
 * @param definition the spy definition
 * @param field the field that declared the definition, or {@code null}
 * @param beanName the name of the bean to spy on; may be {@code null} when
 * several candidates existed and none was primary
 */
private void registerSpy(SpyDefinition definition, Field field, String beanName) {
this.spies.put(beanName, definition);
this.beanNameRegistry.put(definition, beanName);
if (field != null) {
this.fieldRegistry.put(field, new RegisteredField(definition, beanName));
}
}
/**
 * Wrap the given bean in a spy when a spy definition was registered for its name.
 * @param bean the bean instance
 * @param beanName the bean name
 * @return the spy, or the bean itself when no spy is registered for the name
 * @throws BeansException declared for callers in the post-processing chain
 */
protected Object createSpyIfNecessary(Object bean, String beanName)
		throws BeansException {
	SpyDefinition spyDefinition = this.spies.get(beanName);
	if (spyDefinition == null) {
		return bean;
	}
	return spyDefinition.createSpy(beanName, bean);
}
/**
 * Visit every field of the bean's class via {@link ReflectionUtils#doWithFields}
 * and inject the registered mock or spy where one exists. The property values are
 * returned unchanged.
 * @param pvs the property values about to be applied
 * @param pds the property descriptors of the bean (unused)
 * @param bean the bean instance being populated
 * @param beanName the bean name (unused)
 * @return {@code pvs}, unmodified
 * @throws BeansException declared by the callback contract
 */
@Override
public PropertyValues postProcessPropertyValues(PropertyValues pvs,
		PropertyDescriptor[] pds, final Object bean, String beanName)
		throws BeansException {
	Class<?> beanType = bean.getClass();
	FieldCallback injector = new FieldCallback() {

		@Override
		public void doWith(Field field)
				throws IllegalArgumentException, IllegalAccessException {
			postProcessField(bean, field);
		}

	};
	ReflectionUtils.doWithFields(beanType, injector);
	return pvs;
}
/**
 * Inject into a single field if it was registered with a non-empty bean name.
 * Fields that are unknown, or registered without a bean name, are left untouched.
 * @param bean the bean instance owning the field
 * @param field the field to process
 */
private void postProcessField(Object bean, Field field) {
	RegisteredField registered = this.fieldRegistry.get(field);
	if (registered == null) {
		return;
	}
	String registeredName = registered.getBeanName();
	if (!StringUtils.hasLength(registeredName)) {
		return;
	}
	inject(field, bean, registeredName, registered.getDefinition());
}
/**
 * Inject the bean registered for the given definition into a field of the target.
 * @param field the field to set
 * @param target the object whose field is set
 * @param definition the definition whose bean should be injected
 * @throws IllegalStateException if no (non-empty) bean name was registered for
 * the definition
 */
void inject(Field field, Object target, Definition definition) {
String beanName = this.beanNameRegistry.get(definition);
Assert.state(StringUtils.hasLength(beanName),
"No bean found for definition " + definition);
inject(field, target, beanName, definition);
}
/**
 * Look up the named bean and assign it to the field of the target. The field must
 * currently be {@code null}. Every failure, including the "existing value"
 * assertion, is wrapped in a {@link BeanCreationException}.
 * @param field the field to set
 * @param target the object whose field is set
 * @param beanName the name of the bean to inject
 * @param definition the definition the bean was registered for
 */
private void inject(Field field, Object target, String beanName,
Definition definition) {
try {
field.setAccessible(true);
// Refuse to overwrite a value that has already been assigned.
Assert.state(ReflectionUtils.getField(field, target) == null,
"The field " + field + " cannot have an existing value");
Object bean = this.beanFactory.getBean(beanName, field.getType());
// Only definitions flagged as proxy-target-aware get the interceptor, and
// only when the looked-up bean is an AOP proxy.
if (definition.isProxyTargetAware() && isAopProxy(bean)) {
MockitoAopProxyTargetInterceptor.applyTo(bean);
}
ReflectionUtils.setField(field, target, bean);
}
catch (Throwable ex) {
throw new BeanCreationException("Could not inject field: " + field, ex);
}
}
/**
 * Check whether the given object is an AOP proxy, as decided by {@link AopUtils}.
 * Any failure while performing the check is treated as "not a proxy".
 * @param object the object to check
 * @return {@code true} if the object is an AOP proxy
 */
private boolean isAopProxy(Object object) {
try {
return AopUtils.isAopProxy(object);
}
catch (Throwable ex) {
// Deliberate best-effort. NOTE(review): presumably guards against the AOP
// support classes being unavailable at runtime - confirm.
return false;
}
}
/**
 * Return the order of this post processor: ten less than
 * {@link Ordered#LOWEST_PRECEDENCE}.
 * @return the order value
 */
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE - 10;
}
/**
 * Register the processor with a {@link BeanDefinitionRegistry}. Not required when
 * using the {@link SpringRunner} as registration is automatic. No initial
 * definitions are added: {@code null} is passed on to the overload that accepts
 * definitions.
 * @param registry the bean definition registry
 */
public static void register(BeanDefinitionRegistry registry) {
register(registry, null);
}
/**
 * Register the processor with a {@link BeanDefinitionRegistry}. Not required when
 * using the {@link SpringRunner} as registration is automatic. The processor is
 * registered using the {@link MockitoPostProcessor} class itself.
 * @param registry the bean definition registry
 * @param definitions the initial mock/spy definitions, or {@code null} for none
 */
public static void register(BeanDefinitionRegistry registry,
Set<Definition> definitions) {
register(registry, MockitoPostProcessor.class, definitions);
}
/**
 * Register the processor with a {@link BeanDefinitionRegistry}. Not required when
 * using the {@link SpringRunner} as registration is automatic.
 * <p>
 * The companion {@link SpyPostProcessor} is registered as well. The processor's
 * bean definition is created on first use and reused afterwards; the supplied
 * definitions are appended to the set held as its first constructor argument.
 * @param registry the bean definition registry
 * @param postProcessor the post processor class to register
 * @param definitions the initial mock/spy definitions, or {@code null} for none
 */
@SuppressWarnings("unchecked")
public static void register(BeanDefinitionRegistry registry,
		Class<? extends MockitoPostProcessor> postProcessor,
		Set<Definition> definitions) {
	SpyPostProcessor.register(registry);
	BeanDefinition processorDefinition = getOrAddBeanDefinition(registry,
			postProcessor);
	ConstructorArgumentValues arguments = processorDefinition
			.getConstructorArgumentValues();
	ValueHolder holder = arguments.getIndexedArgumentValue(0, Set.class);
	Set<Definition> registered = (Set<Definition>) holder.getValue();
	if (definitions == null) {
		return;
	}
	registered.addAll(definitions);
}
/**
 * Return the bean definition registered for this post processor, creating and
 * registering it first if the registry does not contain one yet.
 * <p>
 * A newly created definition has the infrastructure role and carries an empty,
 * mutable set as its first constructor argument. That set is typed for
 * {@link Definition} because {@code register} reads it back as a
 * {@code Set<Definition>} and adds whatever definitions it is given to it.
 * @param registry the bean definition registry
 * @param postProcessor the post processor class to use for a new definition
 * @return the existing or newly registered bean definition
 */
private static BeanDefinition getOrAddBeanDefinition(BeanDefinitionRegistry registry,
		Class<? extends MockitoPostProcessor> postProcessor) {
	if (registry.containsBeanDefinition(BEAN_NAME)) {
		return registry.getBeanDefinition(BEAN_NAME);
	}
	RootBeanDefinition definition = new RootBeanDefinition(postProcessor);
	definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
	ConstructorArgumentValues constructorArguments = definition
			.getConstructorArgumentValues();
	constructorArguments.addIndexedArgumentValue(0,
			new LinkedHashSet<Definition>());
	registry.registerBeanDefinition(BEAN_NAME, definition);
	return definition;
}
/**
 * {@link BeanPostProcessor} to handle {@link SpyBean} definitions. Registered as a
 * separate processor so that it can be ordered above AOP post processors.
 * <p>
 * A spy can be created at two points: when an early bean reference is requested
 * and after initialization. Beans handed out as an early reference are remembered
 * so that they are not wrapped a second time after initialization, and
 * {@link FactoryBean} instances are never wrapped at either point.
 */
static class SpyPostProcessor extends InstantiationAwareBeanPostProcessorAdapter
		implements PriorityOrdered {

	private static final String BEAN_NAME = SpyPostProcessor.class.getName();

	/**
	 * Raw bean instances for which an early reference has been produced, keyed by
	 * {@link #getCacheKey(Object, String)}.
	 */
	private final Map<String, Object> earlySpyReferences = new ConcurrentHashMap<>(
			16);

	private final MockitoPostProcessor mockitoPostProcessor;

	SpyPostProcessor(MockitoPostProcessor mockitoPostProcessor) {
		this.mockitoPostProcessor = mockitoPostProcessor;
	}

	@Override
	public int getOrder() {
		return Ordered.HIGHEST_PRECEDENCE;
	}

	/**
	 * Produce the early reference for a bean, wrapping it in a spy when one is
	 * registered for its name.
	 * @param bean the raw bean instance
	 * @param beanName the bean name
	 * @return the spy, or the bean itself
	 */
	@Override
	public Object getEarlyBeanReference(Object bean, String beanName)
			throws BeansException {
		if (bean instanceof FactoryBean) {
			// Same rule as after initialization: the factory itself is not spied.
			return bean;
		}
		// Remember the raw instance so it is not wrapped again later.
		this.earlySpyReferences.put(getCacheKey(bean, beanName), bean);
		return createSpyIfNecessary(bean, beanName);
	}

	/**
	 * Wrap the initialized bean in a spy when one is registered for its name,
	 * unless that was already done for an early reference.
	 * @param bean the initialized bean instance
	 * @param beanName the bean name
	 * @return the spy, or the bean itself
	 */
	@Override
	public Object postProcessAfterInitialization(Object bean, String beanName)
			throws BeansException {
		if (bean instanceof FactoryBean) {
			return bean;
		}
		if (this.earlySpyReferences.remove(getCacheKey(bean, beanName)) == bean) {
			// A spy was already created for the early reference. Returning the
			// raw bean avoids creating a second, different spy. The container is
			// expected to keep using the early reference in that case (the same
			// approach Spring's auto-proxy creator takes).
			return bean;
		}
		return createSpyIfNecessary(bean, beanName);
	}

	private Object createSpyIfNecessary(Object bean, String beanName) {
		return this.mockitoPostProcessor.createSpyIfNecessary(bean, beanName);
	}

	/**
	 * Build the key under which an early reference is tracked: the bean name, or
	 * the bean's class name when no bean name is available.
	 */
	private String getCacheKey(Object bean, String beanName) {
		return StringUtils.hasLength(beanName) ? beanName
				: bean.getClass().getName();
	}

	/**
	 * Register this post processor with the given registry unless it is already
	 * present. The definition references the {@link MockitoPostProcessor} bean as
	 * its constructor argument.
	 * @param registry the bean definition registry
	 */
	public static void register(BeanDefinitionRegistry registry) {
		if (!registry.containsBeanDefinition(BEAN_NAME)) {
			RootBeanDefinition definition = new RootBeanDefinition(
					SpyPostProcessor.class);
			definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
			ConstructorArgumentValues constructorArguments = definition
					.getConstructorArgumentValues();
			constructorArguments.addIndexedArgumentValue(0,
					new RuntimeBeanReference(MockitoPostProcessor.BEAN_NAME));
			registry.registerBeanDefinition(BEAN_NAME, definition);
		}
	}

}
/**
 * A registered field item: pairs the definition declared on a field with the name
 * of the bean that should be injected into it.
 */
private static class RegisteredField {
// The mock or spy definition declared on the field.
private final Definition definition;
// Name of the bean to inject; callers check it for emptiness before use.
private final String beanName;
RegisteredField(Definition definition, String beanName) {
this.definition = definition;
this.beanName = beanName;
}
public Definition getDefinition() {
return this.definition;
}
public String getBeanName() {
return this.beanName;
}
}
}