package fr.openwide.core.spring.util; import java.lang.annotation.Annotation; import java.util.LinkedHashSet; import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider; import org.springframework.core.type.filter.AnnotationTypeFilter; import org.springframework.core.type.filter.AssignableTypeFilter; import org.springframework.core.type.filter.TypeFilter; public final class ReflectionUtils { private static final Logger LOGGER = LoggerFactory.getLogger(ReflectionUtils.class); public static <T> Set<Class<? extends T>> findAssignableClasses(String rootPackage, Class<T> clazz) { return findClasses(rootPackage, clazz, new AssignableTypeFilter(clazz)); } public static Set<Class<? extends Object>> findAnnotatedClasses(String rootPackage, Class<? extends Annotation> annotationType) { return findClasses(rootPackage, Object.class, new AnnotationTypeFilter(annotationType)); } @SuppressWarnings("unchecked") private static <T> Set<Class<? extends T>> findClasses(String rootPackage, Class<T> clazz, TypeFilter filter) { ClassPathScanningCandidateComponentProvider scanner = new ClassPathScanningCandidateComponentProvider(false); scanner.addIncludeFilter(filter); Set<BeanDefinition> beanDefinitions = scanner.findCandidateComponents(rootPackage); Set<Class<? extends T>> classes = new LinkedHashSet<Class<? extends T>>(); for (BeanDefinition beanDefinition : beanDefinitions) { try { classes.add((Class<? extends T>) Class.forName(beanDefinition.getBeanClassName())); } catch (ClassNotFoundException e) { LOGGER.warn("Class not found: " + beanDefinition.getBeanClassName()); } } return classes; } private ReflectionUtils() { } }