package com.mogujie.trade.tsharding.route.orm;

import com.mogujie.trade.tsharding.client.ShardingCaculator;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.scripting.defaults.RawSqlSource;
import org.apache.ibatis.scripting.xmltags.*;
import org.apache.ibatis.session.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

/**
 * Mapper SQL enhancement.
 *
 * <p>Builds a copy of a mapped statement's {@link SqlSource} in which occurrences of the
 * {@code TradeOrder} table name are replaced by a sharding-specific table name. MyBatis does not
 * expose the SQL text held by its {@code SqlSource} implementations, so the private fields are
 * read (and, for raw SQL, written) through reflection.
 *
 * @author qigong on 5/1/15
 */
public class MapperResourceEnhancer extends MapperEnhancer {

    static final Logger logger = LoggerFactory.getLogger(MapperResourceEnhancer.class);

    /** Marker that decides whether a piece of SQL text needs rewriting at all. */
    private static final String TABLE_MARKER = "TradeOrder";

    /** Pressure-test table reference; matched with its leading space, like the normal table. */
    private static final String PRESSURE_TEST_TABLE = " TradeOrderPressureTest";

    public MapperResourceEnhancer(Class<?> mapperClass) {
        super(mapperClass);
    }

    /**
     * Returns a {@link SqlSource} equivalent to the one held by {@code ms}, but with the table
     * name rewritten for the given sharding parameter.
     *
     * <p>Only {@link DynamicSqlSource} and {@link RawSqlSource} are supported. Any failure
     * (reflection problems, unexpected node types, an unsupported {@code SqlSource} type) is
     * logged and reported to the caller as {@code null}; no exception raised inside the rewriting
     * step escapes this method.
     *
     * @param ms            mapped statement whose SQL source is used as the template
     * @param configuration MyBatis configuration used to build the new SQL source
     * @param shardingPara  sharding parameter passed to {@link ShardingCaculator}
     * @return the rewritten SQL source; the original {@link RawSqlSource} instance when its SQL
     *         does not mention {@code TradeOrder}; or {@code null} when enhancement failed
     */
    public SqlSource enhancedShardingSQL(MappedStatement ms, Configuration configuration, Long shardingPara) {
        String tableName = ShardingCaculator.caculateTableName(shardingPara);
        try {
            SqlSource source = ms.getSqlSource();
            if (source instanceof DynamicSqlSource) {
                return enhanceDynamicSqlSource((DynamicSqlSource) source, configuration, tableName, shardingPara);
            }
            if (source instanceof RawSqlSource) {
                return enhanceRawSqlSource((RawSqlSource) source, configuration, tableName, shardingPara);
            }
            // NOTE(review): this exception is caught by the handler below, so callers observe a
            // logged error and a null result rather than the exception itself. Kept as-is because
            // existing callers may rely on the null contract.
            throw new RuntimeException("wrong sqlSource type!" + ms.getResource());
        } catch (Exception e) {
            logger.error("reflect error!, ms resources:" + ms.getResource(), e);
        }
        // Always null on failure: a partially built source (e.g. a RawSqlSource whose parameter
        // mappings could not be copied) must never be handed back to the caller.
        return null;
    }

    /**
     * Rebuilds a dynamic SQL source, rewriting the table name inside its top-level static text
     * nodes. Non-static nodes, and anything nested inside them, are reused untouched.
     */
    private SqlSource enhanceDynamicSqlSource(DynamicSqlSource sqlSource, Configuration configuration,
                                              String tableName, Long shardingPara)
            throws NoSuchFieldException, IllegalAccessException {
        Field rootSqlNodeField = DynamicSqlSource.class.getDeclaredField("rootSqlNode");
        rootSqlNodeField.setAccessible(true);
        MixedSqlNode rootSqlNode = (MixedSqlNode) rootSqlNodeField.get(sqlSource);

        Field contentsField = MixedSqlNode.class.getDeclaredField("contents");
        contentsField.setAccessible(true);
        @SuppressWarnings("unchecked") // MixedSqlNode.contents is declared as List<SqlNode>
        List<SqlNode> sqlNodes = (List<SqlNode>) contentsField.get(rootSqlNode);

        // Resolve the field on StaticTextSqlNode itself rather than on the class of the first
        // child: the first child may be a different node type (or the list may be empty), which
        // previously made the whole enhancement fail.
        Field textField = StaticTextSqlNode.class.getDeclaredField("text");
        textField.setAccessible(true);

        List<SqlNode> rewrittenNodes = new ArrayList<SqlNode>(sqlNodes.size());
        for (SqlNode node : sqlNodes) {
            if (node instanceof StaticTextSqlNode) {
                String text = (String) textField.get(node);
                if (text.contains(TABLE_MARKER)) {
                    rewrittenNodes.add(
                            new StaticTextSqlNode(replaceWithShardingTableName(text, tableName, shardingPara)));
                    continue;
                }
            }
            rewrittenNodes.add(node);
        }
        return new DynamicSqlSource(configuration, new MixedSqlNode(rewrittenNodes));
    }

    /**
     * Rebuilds a raw SQL source with the table name rewritten, carrying over the parameter
     * mappings of the original. Returns the original instance when no rewrite is needed.
     */
    private SqlSource enhanceRawSqlSource(RawSqlSource sqlSource, Configuration configuration,
                                          String tableName, Long shardingPara)
            throws NoSuchFieldException, IllegalAccessException {
        Field sqlSourceField = RawSqlSource.class.getDeclaredField("sqlSource");
        sqlSourceField.setAccessible(true);
        StaticSqlSource staticSqlSource = (StaticSqlSource) sqlSourceField.get(sqlSource);

        Field sqlField = StaticSqlSource.class.getDeclaredField("sql");
        Field parameterMappingsField = StaticSqlSource.class.getDeclaredField("parameterMappings");
        sqlField.setAccessible(true);
        parameterMappingsField.setAccessible(true);

        // SQL handling
        String sql = (String) sqlField.get(staticSqlSource);
        if (!sql.contains(TABLE_MARKER)) {
            return sqlSource;
        }

        String shardedSql = replaceWithShardingTableName(sql, tableName, shardingPara);
        RawSqlSource enhanced = new RawSqlSource(configuration, shardedSql, null);

        // Copy the parameter mappings of the original statement onto the newly built source.
        StaticSqlSource enhancedStaticSqlSource = (StaticSqlSource) sqlSourceField.get(enhanced);
        @SuppressWarnings("unchecked") // StaticSqlSource.parameterMappings is a List<ParameterMapping>
        List<ParameterMapping> parameterMappings =
                (List<ParameterMapping>) parameterMappingsField.get(staticSqlSource);
        parameterMappingsField.set(enhancedStaticSqlSource, parameterMappings);
        return enhanced;
    }

    /**
     * Replaces the table reference in {@code text}.
     *
     * <p>If the text mentions {@code " TradeOrderPressureTest"}, that name gets the suffix
     * produced by {@link ShardingCaculator#getNumberWithZeroSuffix} appended; otherwise every
     * {@code " TradeOrder"} is replaced by {@code tableName}. Both patterns include the leading
     * space, so a table name that is not preceded by a space is left unchanged.
     */
    private String replaceWithShardingTableName(String text, String tableName, Long shardingPara) {
        if (text.contains(PRESSURE_TEST_TABLE)) {
            return text.replace(PRESSURE_TEST_TABLE,
                    PRESSURE_TEST_TABLE + ShardingCaculator.getNumberWithZeroSuffix(shardingPara));
        }
        return text.replace(" TradeOrder", " " + tableName);
    }
}