package com.mogujie.trade.tsharding.route.orm; import com.mogujie.trade.tsharding.client.ShardingCaculator; import javassist.ClassPool; import javassist.CtClass; import javassist.CtMethod; import javassist.bytecode.ClassFile; import javassist.bytecode.ConstPool; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ResultMap; import org.apache.ibatis.mapping.SqlSource; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.factory.DefaultObjectFactory; import org.apache.ibatis.reflection.factory.ObjectFactory; import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory; import org.apache.ibatis.reflection.wrapper.ObjectWrapperFactory; import org.apache.ibatis.session.Configuration; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.HashMap; import java.util.List; import java.util.Map; /** * 通用Mapper增强基类,扩展Mapper sql时需要继承该类 * * @author qigong on 5/1/15 */ public abstract class MapperEnhancer { private static ClassPool pool = ClassPool.getDefault(); private Map<String, Method> methodMap = new HashMap<String, Method>(); private Class<?> mapperClass; public MapperEnhancer(Class<?> mapperClass) { this.mapperClass = mapperClass; } /** * 代码增加方法标记 * * @param record */ public String enhancedShardingSQL(Object record) { return "enhancedShardingSQL"; } public MapperEnhancer() { super(); } /** * 对mapper进行增强,生成新的mapper,并主动加载新mapper类到classloader * * @param mapperClassName */ public static void enhanceMapperClass(String mapperClassName) throws Exception { Class originClass = Class.forName(mapperClassName); Method[] originMethods = originClass.getDeclaredMethods(); CtClass cc = pool.get(mapperClassName); for (CtMethod ctMethod : cc.getDeclaredMethods()) { CtClass enhanceClass = pool.makeInterface(mapperClassName + "Sharding" + ctMethod.getName()); for (long i = 0L; i < 512; i++) { CtMethod newMethod = new CtMethod(ctMethod.getReturnType(), ctMethod.getName() + ShardingCaculator.getNumberWithZeroSuffix(i), ctMethod.getParameterTypes(), enhanceClass); Method method = getOriginMethod(newMethod, originMethods); if(method.getParameterAnnotations()[0].length > 0) { ClassFile ccFile = enhanceClass.getClassFile(); ConstPool constPool = ccFile.getConstPool(); //拷贝注解信息和注解内容,以支持mybatis mapper类的动态绑定 newMethod.getMethodInfo().addAttribute(MapperAnnotationEnhancer.duplicateParameterAnnotationsAttribute(constPool, method)); } enhanceClass.addMethod(newMethod); } Class<?> loadThisClass = enhanceClass.toClass(); //2015.09.22后不再输出类到本地 // enhanceClass.writeFile("."); } } private static Method getOriginMethod(CtMethod ctMethod, Method[] originMethods) { for (Method method : originMethods) { int len = ctMethod.getName().length(); if (ctMethod.getName().substring(0, len-4).equals(method.getName())) { return method; } } throw new RuntimeException("enhanceMapperClass find method error!"); } /** * 添加映射方法 * * @param methodName * @param method */ public void addMethodMap(String methodName, Method method) { methodMap.put(methodName, method); } private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory(); private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory(); /** * 反射对象,增加对低版本Mybatis的支持 * * @param object 反射对象 * @return */ public static MetaObject forObject(Object object) { return MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY); } /** * 是否支持该通用方法 * * @param msId * @return */ public boolean supportMethod(String msId) { Class<?> mapperClass = getMapperClass(msId); if (this.mapperClass.isAssignableFrom(mapperClass)) { String methodName = getMethodName(msId); return methodMap.get(methodName) != null; } return false; } /** * 重新设置SqlSource * * @param ms * @param sqlSource */ protected void setSqlSource(MappedStatement ms, SqlSource sqlSource) { MetaObject msObject = forObject(ms); msObject.setValue("sqlSource", sqlSource); } /** * 重新设置SqlSource * * @param ms * @throws java.lang.reflect.InvocationTargetException * @throws IllegalAccessException */ public void setSqlSource(MappedStatement ms, Configuration configuration) throws Exception { Method method = methodMap.get(getMethodName(ms)); try { if (method.getReturnType() == Void.TYPE) { method.invoke(this, ms); } else if (SqlSource.class.isAssignableFrom(method.getReturnType())) { //代码增强 扩充为512个方法。 for (long i = 0; i < 512; i++) { //新的带sharding的sql SqlSource sqlSource = (SqlSource) method.invoke(this, ms, configuration, i); String newMsId = ms.getId() + ShardingCaculator.getNumberWithZeroSuffix(i); newMsId = newMsId.replace("Mapper.", "MapperSharding" + getMethodName(ms) + "."); //添加到ms库中 MappedStatement newMs = copyFromMappedStatement(ms, sqlSource, newMsId); configuration.addMappedStatement(newMs); setSqlSource(newMs, sqlSource); } } else { throw new RuntimeException("自定义Mapper方法返回类型错误,可选的返回类型为void和SqlNode!"); } } catch (IllegalAccessException e) { throw new RuntimeException(e); } catch (InvocationTargetException e) { throw new RuntimeException(e.getTargetException() != null ? e.getTargetException() : e); } } protected MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource, String newMsId) { MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), newMsId, newSqlSource, ms.getSqlCommandType()); builder.resource(ms.getResource()); builder.fetchSize(ms.getFetchSize()); builder.statementType(ms.getStatementType()); builder.keyGenerator(ms.getKeyGenerator()); // setStatementTimeout() builder.timeout(ms.getTimeout()); // setParameterMap() builder.parameterMap(ms.getParameterMap()); // setStatementResultMap() List<ResultMap> resultMaps = ms.getResultMaps(); builder.resultMaps(resultMaps); builder.resultSetType(ms.getResultSetType()); // setStatementCache() builder.cache(ms.getCache()); builder.flushCacheRequired(ms.isFlushCacheRequired()); builder.useCache(ms.isUseCache()); return builder.build(); } /** * 根据msId获取接口类 * * @param msId * @return * @throws ClassNotFoundException */ public static Class<?> getMapperClass(String msId) { String mapperClassStr = msId.substring(0, msId.lastIndexOf(".")); try { return Class.forName(mapperClassStr); } catch (ClassNotFoundException e) { throw new RuntimeException("无法获取Mapper接口信息:" + msId); } } /** * 获取执行的方法名 * * @param ms * @return */ public static String getMethodName(MappedStatement ms) { return getMethodName(ms.getId()); } /** * 获取执行的方法名 * * @param msId * @return */ public static String getMethodName(String msId) { return msId.substring(msId.lastIndexOf(".") + 1); } }