package org.unbrokendome.eventbus.proxy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import java.util.stream.Stream;
/**
 * Registry post-processor that registers one additional "proxy" bean definition for every
 * event subscriber method found on the beans already defined in the bean factory.
 *
 * <p>For each bean definition name, the bean's type is resolved through the bean factory and
 * handed to a {@link SubscriberScanner}. Every {@link EventSubscriberInfo} the scanner reports
 * is turned into a bean definition whose class is produced by a
 * {@link SubscriberProxyClassGenerator} and whose single constructor argument is a reference
 * to the bean that declares the subscriber method.
 */
public class EventSubscriberBeanPostProcessor implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {

    private static final Logger logger = LoggerFactory.getLogger(EventSubscriberBeanPostProcessor.class);

    /** Set through {@link #setBeanFactory(BeanFactory)}; {@code null} until then. */
    private ConfigurableListableBeanFactory beanFactory;

    private final SubscriberScanner subscriberScanner;
    private final SubscriberProxyClassGenerator proxyClassGenerator;

    /**
     * Creates a post-processor that uses a {@link ReflectiveSubscriberScanner} and a
     * {@link CglibSubscriberProxyClassGenerator}.
     */
    public EventSubscriberBeanPostProcessor() {
        this(new ReflectiveSubscriberScanner(), new CglibSubscriberProxyClassGenerator());
    }

    /**
     * Creates a post-processor with explicitly supplied collaborators.
     *
     * @param subscriberScanner   locates subscriber methods on a bean type
     * @param proxyClassGenerator generates the proxy class for a subscriber method
     */
    EventSubscriberBeanPostProcessor(SubscriberScanner subscriberScanner,
                                     SubscriberProxyClassGenerator proxyClassGenerator) {
        this.subscriberScanner = subscriberScanner;
        this.proxyClassGenerator = proxyClassGenerator;
    }

    /**
     * Stores the owning bean factory.
     *
     * @param beanFactory the owning factory; must be a {@link ConfigurableListableBeanFactory}
     * @throws IllegalArgumentException if {@code beanFactory} is of any other type
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) {
        if (!(beanFactory instanceof ConfigurableListableBeanFactory)) {
            throw new IllegalArgumentException(EventSubscriberBeanPostProcessor.class.getSimpleName()
                    + " requires a ConfigurableListableBeanFactory");
        }
        this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
    }

    /**
     * Scans every currently defined bean for subscriber methods and registers a proxy bean
     * definition in {@code registry} for each one found.
     *
     * @param registry the registry that receives the generated proxy bean definitions
     * @throws IllegalStateException if {@link #setBeanFactory(BeanFactory)} has not been called
     * @throws BeansException        as declared by the implemented interface
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        if (beanFactory == null) {
            // Fail with a clear message instead of a bare NullPointerException further down.
            throw new IllegalStateException(EventSubscriberBeanPostProcessor.class.getSimpleName()
                    + " requires setBeanFactory to be called before post-processing");
        }
        Stream.of(beanFactory.getBeanDefinitionNames())
                // flatMap treats a null mapped stream as empty, as it did in the original pipeline
                .flatMap(this::scanBean)
                .map(this::makeProxyBeanDefinition)
                .forEach(proxy -> registry.registerBeanDefinition(proxy.getName(), proxy.getDefinition()));
    }

    /**
     * Intentionally empty: all work happens in
     * {@link #postProcessBeanDefinitionRegistry(BeanDefinitionRegistry)}.
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
    }

    /**
     * Resolves the type of the named bean and asks the scanner for its subscriber methods.
     *
     * @param beanName name of the bean definition to inspect
     * @return whatever the scanner returns for the bean, or an empty stream when the bean
     *         factory cannot determine the bean's type
     */
    private Stream<? extends EventSubscriberInfo> scanBean(String beanName) {
        Class<?> beanType = beanFactory.getType(beanName);
        if (beanType == null) {
            // BeanFactory.getType returns null when the type is not determinable;
            // there is nothing to scan in that case, so skip the bean.
            logger.trace("Skipping bean {}: its type could not be determined", beanName);
            return Stream.empty();
        }
        logger.trace(
                "Inspecting bean {} of type {} for event subscriber methods",
                beanName, beanType);
        return subscriberScanner.scanForSubscriberMethods(beanName, beanType);
    }

    /** Pairs the generated proxy bean name with its bean definition. */
    private BeanNameAndDefinition makeProxyBeanDefinition(EventSubscriberInfo subscriberInfo) {
        return new BeanNameAndDefinition(
                makeProxyBeanName(subscriberInfo),
                createProxyBeanDefinition(subscriberInfo));
    }

    /**
     * Builds the proxy bean name as
     * {@code <beanName>##eventSubscriber_<methodName>_<eventTypeSimpleName>}.
     */
    private String makeProxyBeanName(EventSubscriberInfo subscriberInfo) {
        // NOTE(review): only the event type's simple name is used, so two event types with the
        // same simple name on same-named methods of one bean would yield the same bean name.
        return subscriberInfo.getBeanName() + "##eventSubscriber_"
                + subscriberInfo.getSubscriberMethodName() + "_"
                + subscriberInfo.getEventType().getSimpleName();
    }

    /**
     * Generates the proxy class with the bean factory's class loader and wraps it in a bean
     * definition that takes a reference to the subscriber bean as its constructor argument.
     */
    private BeanDefinition createProxyBeanDefinition(EventSubscriberInfo subscriberInfo) {
        Class<?> proxyClass = proxyClassGenerator.generate(subscriberInfo, beanFactory.getBeanClassLoader());
        return BeanDefinitionBuilder.genericBeanDefinition(proxyClass)
                .addConstructorArgReference(subscriberInfo.getBeanName())
                .getBeanDefinition();
    }

    /** Immutable holder for a bean name and the bean definition to register under it. */
    private static class BeanNameAndDefinition {

        private final String name;
        private final BeanDefinition definition;

        private BeanNameAndDefinition(String name, BeanDefinition definition) {
            this.name = name;
            this.definition = definition;
        }

        public String getName() {
            return name;
        }

        public BeanDefinition getDefinition() {
            return definition;
        }
    }
}