package io.mycat.route.parser.druid.impl; import com.alibaba.druid.sql.ast.SQLStatement; import com.alibaba.druid.sql.visitor.SchemaStatVisitor; import com.alibaba.druid.stat.TableStat.Condition; import io.mycat.cache.LayerCachePool; import io.mycat.route.RouteResultset; import io.mycat.route.parser.druid.DruidParser; import io.mycat.route.parser.druid.DruidShardingParseInfo; import io.mycat.route.parser.druid.MycatSchemaStatVisitor; import io.mycat.route.parser.druid.RouteCalculateUnit; import io.mycat.server.config.node.SchemaConfig; import io.mycat.sqlengine.mpp.RangeValue; import io.mycat.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.sql.SQLNonTransientException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * 对SQLStatement解析 * 主要通过visitor解析和statement解析:有些类型的SQLStatement通过visitor解析足够了, * 有些只能通过statement解析才能得到所有信息 * 有些需要通过两种方式解析才能得到完整信息 * @author wang.dw * */ public class DefaultDruidParser implements DruidParser { protected static final Logger LOGGER = LoggerFactory .getLogger(DefaultDruidParser.class); /** * 解析得到的结果 */ protected DruidShardingParseInfo ctx; private Map<String,String> tableAliasMap = new HashMap<String,String>(); private List<Condition> conditions = new ArrayList<Condition>(); public Map<String, String> getTableAliasMap() { return tableAliasMap; } public List<Condition> getConditions() { return conditions; } /** * 使用MycatSchemaStatVisitor解析,得到tables、tableAliasMap、conditions等 * @param schema * @param stmt */ public void parser(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt, String originSql,LayerCachePool cachePool,MycatSchemaStatVisitor schemaStatVisitor) throws SQLNonTransientException { ctx = new DruidShardingParseInfo(); //设置为原始sql,如果有需要改写sql的,可以通过修改SQLStatement中的属性,然后调用SQLStatement.toString()得到改写的sql ctx.setSql(originSql); //通过visitor解析 visitorParse(rrs,stmt,schemaStatVisitor); //通过Statement解析 statementParse(schema, rrs, stmt); //改写sql:如insert语句主键自增长的可以 changeSql(schema, rrs, stmt,cachePool); } /** * 子类可覆盖(如果visitorParse解析得不到表名、字段等信息的,就通过覆盖该方法来解析) * 子类覆盖该方法一般是将SQLStatement转型后再解析(如转型为MySqlInsertStatement) */ @Override public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException { } /** * 改写sql:如insert是 */ @Override public void changeSql(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt,LayerCachePool cachePool) throws SQLNonTransientException { } /** * 子类可覆盖(如果该方法解析得不到表名、字段等信息的,就覆盖该方法,覆盖成空方法,然后通过statementPparse去解析) * 通过visitor解析:有些类型的Statement通过visitor解析得不到表名、 * @param stmt */ @Override public void visitorParse(RouteResultset rrs, SQLStatement stmt,MycatSchemaStatVisitor visitor) throws SQLNonTransientException{ stmt.accept(visitor); List<List<Condition>> mergedConditionList = new ArrayList<List<Condition>>(); if(visitor.hasOrCondition()) {//包含or语句 //TODO //根据or拆分 mergedConditionList = visitor.splitConditions(); } else {//不包含OR语句 mergedConditionList.add(visitor.getConditions()); } if(visitor.getAliasMap() != null) { for(Map.Entry<String, String> entry : visitor.getAliasMap().entrySet()) { String key = entry.getKey(); String value = entry.getValue(); if(key != null && key.indexOf("`") >= 0) { key = key.replaceAll("`", ""); } if(value != null && value.indexOf("`") >= 0) { value = value.replaceAll("`", ""); } //表名前面带database的,去掉 if(key != null) { int pos = key.indexOf("."); if(pos> 0) { key = key.substring(pos + 1); } } if(key.equals(value)) { ctx.addTable(key.toUpperCase()); } else { tableAliasMap.put(key, value); } } ctx.setTableAliasMap(tableAliasMap); } ctx.setRouteCalculateUnits(this.buildRouteCalculateUnits(visitor, mergedConditionList)); } private List<RouteCalculateUnit> buildRouteCalculateUnits(SchemaStatVisitor visitor, List<List<Condition>> conditionList) { List<RouteCalculateUnit> retList = new ArrayList<RouteCalculateUnit>(); //遍历condition ,找分片字段 for(int i = 0; i < conditionList.size(); i++) { RouteCalculateUnit routeCalculateUnit = new RouteCalculateUnit(); for(Condition condition : conditionList.get(i)) { List<Object> values = condition.getValues(); if(values.size() == 0) { break; } if(checkConditionValues(values)) { String columnName = StringUtil.removeBackquote(condition.getColumn().getName().toUpperCase()); String tableName = StringUtil.removeBackquote(condition.getColumn().getTable().toUpperCase()); if(visitor.getAliasMap() != null && visitor.getAliasMap().get(condition.getColumn().getTable()) == null) {//子查询的别名条件忽略掉,不参数路由计算,否则后面找不到表 continue; } String operator = condition.getOperator(); //只处理between ,in和=3中操作符 if(operator.equals("between")) { RangeValue rv = new RangeValue(values.get(0), values.get(1), RangeValue.EE); routeCalculateUnit.addShardingExpr(tableName.toUpperCase(), columnName, rv); } else if(operator.equals("=") || operator.toLowerCase().equals("in")){ //只处理=号和in操作符,其他忽略 routeCalculateUnit.addShardingExpr(tableName.toUpperCase(), columnName, values.toArray()); } } } retList.add(routeCalculateUnit); } return retList; } private boolean checkConditionValues(List<Object> values) { for(Object value : values) { if(value != null && !value.toString().equals("")) { return true; } } return false; } public DruidShardingParseInfo getCtx() { return ctx; } }