package jakiro.mybatis; import jakiro.datasource.threadlocal.DataSourceContextHolder; import jakiro.util.Pair; import jakiro.util.ReflectionUtils; import jakiro.util.SQLParser; import jakiro.util.StringUtils; import jakiro.util.Validate; import java.sql.Connection; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import org.apache.ibatis.executor.statement.RoutingStatementHandler; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; @Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})}) public class Interceptor4DB implements Interceptor { // 输入 private Map<String, String> tableNameVsDataSource = new HashMap<String, String>(); private Map<String, TableNameHandler> tableNameVsHandler = new HashMap<String, TableNameHandler>(); // cache sql parser private Map<String, Pair<String, String>> idVSTableNameType = new ConcurrentHashMap<String, Pair<String, String>>(); // cache result private Map<String, Pair<String, String>> id4TableName = new ConcurrentHashMap<String, Pair<String, String>>(); private Map<String, Object> id4DataSource = new ConcurrentHashMap<String, Object>(); private final static Pair<String, String> Skip4HandleName = Pair.of("", ""); private static final Object Skip4DataSource = new Object(); // 配置 private String prefix = "_shared_a0b9c8d7e6_"; public void setPrefix(String prefix) { this.prefix = prefix; } private boolean noCache4Sql = false; public void setNoCache4Sql(boolean noCache4Sql) { this.noCache4Sql = noCache4Sql; } private boolean cleanThreadLocalFirst = false; public void setCleanThreadLocalFirst(boolean cleanThreadLocalFirst) { this.cleanThreadLocalFirst = cleanThreadLocalFirst; } @Override public Object intercept(Invocation invocation) throws Throwable { StatementHandler statementHandler = (StatementHandler) invocation.getTarget(); MappedStatement mappedStatement = null; if (statementHandler instanceof RoutingStatementHandler) { StatementHandler delegate = (StatementHandler) ReflectionUtils.getFieldValue(statementHandler, "delegate"); mappedStatement = (MappedStatement) ReflectionUtils.getFieldValue(delegate, "mappedStatement"); } else { mappedStatement = (MappedStatement) ReflectionUtils.getFieldValue(statementHandler, "mappedStatement"); } String mapperId = mappedStatement.getId(); Object params = statementHandler.getBoundSql().getParameterObject(); handleTableName(mapperId, statementHandler, params); handleDataSource(mapperId); try { return invocation.proceed(); } finally { DataSourceContextHolder.clearDataSourceName(); } } private void handleDataSource(String mapperId) throws Exception { if (cleanThreadLocalFirst) { DataSourceContextHolder.clearDataSourceName(); } if (!id4DataSource.containsKey(mapperId)) { Pair<String, String> tableNameAndType = idVSTableNameType.get(mapperId); if (tableNameAndType == null) { DataSourceContextHolder.clearDataSourceName(); throw new Exception("Parse Sql Failure !!!"); } String ds = tableNameVsDataSource.get(tableNameAndType.getLeft()); if (ds == null) { id4DataSource.put(mapperId, Skip4DataSource); } else { DataSourceContextHolder.setDataSourceName(ds); id4DataSource.put(mapperId, ds); } } else if (id4DataSource.get(mapperId) == Skip4DataSource) { } else { DataSourceContextHolder.setDataSourceName((String) id4DataSource.get(mapperId)); } } private void handleTableName(String id, StatementHandler statementHandler, Object params) { Pair<String, String> tn = id4TableName.get(id); if (tn == null || noCache4Sql) { tn = SQLParser.findTableNameAndType(statementHandler.getBoundSql().getSql()); Validate.notNull(tn); idVSTableNameType.put(id, Pair.of(tn.getLeft(), tn.getRight())); tn = tn.getLeft().startsWith(prefix) ? tn : Skip4HandleName; id4TableName.put(id, tn); } if (tn != Skip4HandleName && tableNameVsHandler.get(tn.getLeft()) != null) { String p = tableNameVsHandler.get(tn.getLeft()).getTargetTableName(tn.getRight(), tn.getLeft(), params, id); String sql = statementHandler.getBoundSql().getSql(); if (StringUtils.isNotBlank(sql) && StringUtils.isNotBlank(p)) { String nsql = sql.replaceAll(tn.getLeft(), p); ReflectionUtils.setFieldValue(statementHandler.getBoundSql(), "sql", nsql); } } } @Override public Object plugin(Object target) { return Plugin.wrap(target, this); } @Override public void setProperties(Properties properties) {} public void setTableNameVsDataSource(Properties properties) { if (properties == null || properties.size() == 0) return; for (Entry<Object, Object> e : properties.entrySet()) { tableNameVsDataSource.put((String) e.getKey(), (String) e.getValue()); } } public void setTableNameVsHandler(Properties properties) { if (properties == null || properties.size() == 0) { return; } for (Entry<Object, Object> e : properties.entrySet()) { Object o = null; try { Class<?> c = Class.forName((String) e.getValue()); o = c.newInstance(); Validate.isTrue(o instanceof TableNameHandler); } catch (Exception ec) { Validate.isTrue(false, ec.toString()); } tableNameVsHandler.put((String) e.getKey(), (TableNameHandler) o); } } }