package com.mogujie.trade.tsharding.route.orm; import com.mogujie.trade.db.DataSourceLookup; import com.mogujie.trade.db.ReadWriteSplittingDataSource; import com.mogujie.trade.tsharding.route.orm.base.*; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.session.Configuration; import org.apache.ibatis.session.SqlSessionFactory; import org.mybatis.spring.SqlSessionFactoryBean; import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.core.io.Resource; import java.io.IOException; import java.lang.reflect.*; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; /** * Tsharding MybatisMapper的扫描类,负责将Mapper接口与对应的xml配置文件整合,绑定设定的数据源,注入到Spring Context中。 * * @author qigong */ public class MapperScannerWithSharding implements BeanFactoryPostProcessor, InitializingBean { public static DataSourceLookup dataSourceLookup; private String packageName; private Resource[] mapperLocations; private String[] mapperPacakages; private SqlSessionFactoryLookup sqlSessionFactoryLookup; public static DataSourceLookup getDataSourceLookup() { return dataSourceLookup; } @Override public void afterPropertiesSet() throws Exception { this.initMapperPackage(); } private void initMapperPackage() throws IOException { this.mapperPacakages = packageName.split(","); } @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { this.dataSourceLookup = beanFactory.getBean(DataSourceLookup.class); try { this.initSqlSessionFactories(beanFactory); } catch (Exception e) { throw new RuntimeException(e); } ClassPathScanHandler scanner = new ClassPathScanHandler(); Set<Class<?>> mapperClasses = new HashSet<>(); for (String mapperPackage : this.mapperPacakages) { Set<Class<?>> classes = scanner.getPackageAllClasses(mapperPackage.trim(), false); mapperClasses.addAll(classes); } for (Class<?> clazz : mapperClasses) { if (isMapper(clazz)) { Object mapper = this.newMapper(clazz); beanFactory.registerSingleton(Character.toLowerCase(clazz.getSimpleName().charAt(0)) + clazz.getSimpleName().substring(1), mapper); } } } private void initSqlSessionFactories(ConfigurableListableBeanFactory beanFactory) throws Exception { Map<String, SqlSessionFactory> sqlSessionFactories = new HashMap<>(this.dataSourceLookup.getMapping().size()); ReadWriteSplittingDataSource defaultDataSource = null; SqlSessionFactory defaultSqlSessionFactory = null; for (ReadWriteSplittingDataSource dataSource : this.dataSourceLookup.getMapping().values()) { SqlSessionFactoryBean sessionFactoryBean = new SqlSessionFactoryBean(); sessionFactoryBean.setMapperLocations(mapperLocations); sessionFactoryBean.setDataSource(dataSource); sessionFactoryBean.setTypeAliasesPackage(this.packageName + ".domain.entity"); // init 初始化所有sql对应的元数据、资源(sqlNode, sqlSource, mappedStatement)等 sessionFactoryBean.afterPropertiesSet(); if (defaultDataSource == null) { //第一个 defaultDataSource = dataSource; defaultSqlSessionFactory = sessionFactoryBean.getObject(); } else { SqlSessionFactory newSqlSessionFactory = sessionFactoryBean.getObject(); Field conf = newSqlSessionFactory.getClass().getDeclaredField("configuration"); conf.setAccessible(true); Configuration newConfiguration = (Configuration) conf.get(newSqlSessionFactory); Field mappedStatementField = newConfiguration.getClass().getDeclaredField("mappedStatements"); //去掉final修饰符 Field modifiersField = Field.class.getDeclaredField("modifiers"); modifiersField.setAccessible(true); modifiersField.setInt( mappedStatementField, mappedStatementField.getModifiers() & ~Modifier.FINAL); mappedStatementField.setAccessible(true); //后续的元数据复用 Configuration defaultConfiguration = defaultSqlSessionFactory.getConfiguration(); Map<String, MappedStatement> reUsedMappedStatement = (Map) mappedStatementField.get(defaultConfiguration); mappedStatementField.set(newConfiguration, reUsedMappedStatement); } beanFactory.registerSingleton(dataSource.getName() + "SqlSessionFactory", sessionFactoryBean); sqlSessionFactories.put(dataSource.getName(), sessionFactoryBean.getObject()); defaultSqlSessionFactory = sessionFactoryBean.getObject(); } this.sqlSessionFactoryLookup = new SqlSessionFactoryLookup(sqlSessionFactories); } private boolean isMapper(Class<?> clazz) { if (clazz.isInterface()) { return true; } return false; } private Object newMapper(final Class<?> clazz) { final Invoker invoker = new TShardingRoutingInvokeFactory(sqlSessionFactoryLookup).newInvoker(clazz); return Proxy.newProxyInstance(clazz.getClassLoader(), new Class[]{clazz}, new InvocationHandler() { @Override public Object invoke(Object proxy, final Method method, final Object[] args) throws Throwable { return invoker.invoke(new DefaultInvocation(method, args)); } }); } /** * 注入packageName配置 * * @param packageName */ public void setPackageName(String packageName) { this.packageName = packageName; } /** * 注入mapperLocations配置 * * @param mapperLocations */ public void setMapperLocations(Resource[] mapperLocations) { this.mapperLocations = mapperLocations; } }