package com.mogujie.trade.tsharding.route.orm; import com.mogujie.trade.tsharding.annotation.ShardingExtensionMethod; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.session.Configuration; import org.apache.ibatis.session.SqlSession; import java.lang.reflect.Method; import java.util.*; /** * Mapper处理主要逻辑,最关键的一个类 * <p/> * <p> * 参考项目地址 : <a href="https://github.com/abel533/Mapper" * target="_blank">https://github.com/abel533/Mapper</a> * </p> * * @author qigong on 5/1/15 */ public class MapperHelperForSharding { /** * 注册的通用Mapper接口 */ private Map<Class<?>, MapperEnhancer> registerMapper = new HashMap<Class<?>, MapperEnhancer>(); /** * 缓存msid和MapperTemplate */ private Map<String, MapperEnhancer> msIdCache = new HashMap<String, MapperEnhancer>(); /** * 缓存skip结果 */ private final Map<String, Boolean> msIdSkip = new HashMap<String, Boolean>(); /** * 缓存已经处理过的Collection<MappedStatement> */ private Set<Collection<MappedStatement>> collectionSet = new HashSet<Collection<MappedStatement>>(); /** * 是否使用的Spring */ private boolean spring = false; /** * 是否为Spring4.x以上版本 */ private boolean spring4 = false; /** * Spring版本号 */ private String springVersion; /** * 缓存初始化时的SqlSession */ private List<SqlSession> sqlSessions = new ArrayList<SqlSession>(); /** * 针对Spring注入需要处理的SqlSession * * @param sqlSessions */ public void setSqlSessions(SqlSession[] sqlSessions) { if (sqlSessions != null && sqlSessions.length > 0) { this.sqlSessions.addAll(Arrays.asList(sqlSessions)); } } /** * Spring初始化方法,使用Spring时需要配置init-method="initMapper" */ public void initMapper() { // 只有Spring会执行这个方法,所以Spring配置的时候,从这儿可以尝试获取Spring的版本 // 先判断Spring版本,对下面的操作有影响 // Spring4以上支持泛型注入,因此可以扫描通用Mapper if (!initSpringVersion()) { throw new RuntimeException("Error! Spring4 is necessary!"); } for (SqlSession sqlSession : sqlSessions) { processConfiguration(sqlSession.getConfiguration()); } } /** * 检测Spring版本号,Spring4.x以上支持泛型注入 */ private boolean initSpringVersion() { try { // 反射获取SpringVersion Class<?> springVersionClass = Class.forName("org.springframework.core.SpringVersion"); springVersion = (String) springVersionClass.getDeclaredMethod("getVersion", new Class<?>[0]).invoke(null, new Object[0]); spring = true; if (springVersion.indexOf(".") > 0) { int MajorVersion = Integer.parseInt(springVersion.substring(0, springVersion.indexOf("."))); if (MajorVersion > 3) { spring4 = true; } else { spring4 = false; } } } catch (Exception e) { spring = false; spring4 = false; } return spring && spring4; } /** * 通过通用Mapper接口获取对应的MapperTemplate * * @param mapperClass */ private MapperEnhancer fromMapperClass(Class<?> mapperClass) { Method[] methods = mapperClass.getDeclaredMethods(); Class<?> templateClass = null; Class<?> tempClass = null; Set<String> methodSet = new HashSet<String>(); for (Method method : methods) { if (method.isAnnotationPresent(ShardingExtensionMethod.class)) { ShardingExtensionMethod annotation = method.getAnnotation(ShardingExtensionMethod.class); tempClass = annotation.type(); methodSet.add(method.getName()); } if (templateClass == null) { templateClass = tempClass; } else if (templateClass != tempClass) { throw new RuntimeException("一个通用Mapper中只允许存在一个MapperTemplate子类!"); } } if (templateClass == null || !MapperEnhancer.class.isAssignableFrom(templateClass)) { throw new RuntimeException("接口中不存在包含type为MapperTemplate的Provider注解,这不是一个合法的通用Mapper接口类!"); } MapperEnhancer mapperEnhancer = null; try { mapperEnhancer = (MapperEnhancer) templateClass.getConstructor(Class.class).newInstance(mapperClass); } catch (Exception e) { throw new RuntimeException("实例化MapperTemplate对象失败:" + e.getMessage(), e); } // 注册方法 for (String methodName : methodSet) { try { mapperEnhancer.addMethodMap(methodName, templateClass.getMethod("enhancedShardingSQL", MappedStatement.class, Configuration.class, Long.class)); } catch (NoSuchMethodException e) { throw new RuntimeException(templateClass.getCanonicalName() + "中缺少enhancedShardingSQL方法!"); } } return mapperEnhancer; } /** * 注册通用Mapper接口 * * @param mapperClass * @throws Exception */ public void registerMapper(Class<?> mapperClass) { if (registerMapper.get(mapperClass) == null) { MapperEnhancer enhancer = fromMapperClass(mapperClass); registerMapper.put(mapperClass, enhancer); } else { throw new RuntimeException("已经注册过的通用Mapper[" + mapperClass.getCanonicalName() + "]不能多次注册!"); } } /** * 注册通用Mapper接口 * * @param mapperClass * @throws Exception */ public void registerMapper(String mapperClass) { try { registerMapper(Class.forName(mapperClass)); } catch (ClassNotFoundException e) { throw new RuntimeException("注册通用Mapper[" + mapperClass + "]失败,找不到该通用Mapper!"); } } /** * 方便Spring注入 * * @param mappers */ public void setMappers(String[] mappers) { if (mappers != null && mappers.length > 0) { for (String mapper : mappers) { registerMapper(mapper); } } } /** * 判断当前的接口方法是否需要进行拦截 * * @param msId * @return */ public boolean isMapperMethod(String msId) { if (msIdSkip.get(msId) != null) { return msIdSkip.get(msId); } for (Map.Entry<Class<?>, MapperEnhancer> entry : registerMapper.entrySet()) { if (entry.getValue().supportMethod(msId)) { msIdSkip.put(msId, true); return true; } } msIdSkip.put(msId, false); return false; } /** * 获取MapperTemplate * * @param msId * @return */ private MapperEnhancer getMapperTemplate(String msId) { MapperEnhancer mapperEnhancer = null; if (msIdCache.get(msId) != null) { mapperEnhancer = msIdCache.get(msId); } else { for (Map.Entry<Class<?>, MapperEnhancer> entry : registerMapper.entrySet()) { if (entry.getValue().supportMethod(msId)) { mapperEnhancer = entry.getValue(); break; } } msIdCache.put(msId, mapperEnhancer); } return mapperEnhancer; } /** * 重新设置SqlSource * * @param ms */ public void setSqlSource(MappedStatement ms, Configuration configuration) { MapperEnhancer mapperEnhancer = getMapperTemplate(ms.getId()); try { if (mapperEnhancer != null) { mapperEnhancer.setSqlSource(ms, configuration); } } catch (Exception e) { throw new RuntimeException("调用方法异常:" + e.getMessage(), e); } } /** * 处理configuration中全部的MappedStatement * * @param configuration */ public void processConfiguration(Configuration configuration) { Collection<MappedStatement> collection = configuration.getMappedStatements(); // 防止反复处理一个 if (collectionSet.contains(collection)) { return; } else { collectionSet.add(collection); } Collection<MappedStatement> tmpCollection = new HashSet<>(); tmpCollection.addAll(collection); Iterator<MappedStatement> iterator = tmpCollection.iterator(); while (iterator.hasNext()) { Object object = iterator.next(); if (object instanceof MappedStatement) { MappedStatement ms = (MappedStatement) object; if (isMapperMethod(ms.getId())) { setSqlSource(ms, configuration); } } } } }