package io.mycat.route.parser.druid;
import io.mycat.route.util.RouterUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.expr.SQLBetweenExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.statement.SQLAlterTableItem;
import com.alibaba.druid.sql.ast.statement.SQLAlterTableStatement;
import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.stat.TableStat.Column;
import com.alibaba.druid.stat.TableStat.Condition;
import com.alibaba.druid.stat.TableStat.Mode;
/**
 * Visitor used by the Druid-based parser to extract table names, conditions,
 * columns, etc. from the SQL AST.
 * @author wang.dw
 *
 */
public class MycatSchemaStatVisitor extends MySqlSchemaStatVisitor {
// Set to true by visit(SQLBinaryOpExpr) when a BooleanOr expression that is not always-true is met; cleared by reset().
private boolean hasOrCondition = false;
// Where-units collected by visit(SQLBinaryOpExpr) during the current traversal; cleared by reset().
private List<WhereUnit> whereUnits = new CopyOnWriteArrayList<WhereUnit>();
// Where-units copied from whereUnits by splitConditions(); reset() does not clear this list.
private List<WhereUnit> storedwhereUnits = new CopyOnWriteArrayList<WhereUnit>();
/**
 * Clears the per-traversal state: the OR flag, the collected where-units
 * and the collected conditions. {@code storedwhereUnits} is left untouched.
 */
private void reset() {
    hasOrCondition = false;
    whereUnits.clear();
    conditions.clear();
}
/**
 * Returns the where-units collected so far.
 *
 * @return the internal list itself (not a copy)
 */
public List<WhereUnit> getWhereUnits() {
    return this.whereUnits;
}
/**
 * Tells whether an OR condition was flagged since the last reset.
 *
 * @return the current value of the OR flag
 */
public boolean hasOrCondition() {
    return this.hasOrCondition;
}
/**
 * Prepares the alias map when a SELECT statement is entered.
 *
 * @param x the SELECT statement being visited
 * @return always {@code true}, so the statement's children are visited
 */
@Override
public boolean visit(SQLSelectStatement x) {
    setAliasMap();
    return true;
}
/**
 * Records a BETWEEN predicate as a {@code "between"} condition on its column.
 * <p>
 * If a "between" condition for the same column already exists it is reused,
 * otherwise a new one is created and added to the collected conditions. The
 * begin and end bounds are appended (in that order) to the condition's values.
 * Nothing is recorded when the column cannot be resolved.
 *
 * @param x the BETWEEN expression being visited
 * @return always {@code true}, so the expression's children are visited
 */
@Override
public boolean visit(SQLBetweenExpr x) {
    String begin = betweenBoundAsString(x.beginExpr);
    String end = betweenBoundAsString(x.endExpr);

    Column column = getColumn(x);
    if (column == null) {
        return true;
    }

    Condition condition = null;
    for (Condition item : this.getConditions()) {
        // Constant-first comparison: a condition without an operator simply does not match.
        if (item.getColumn().equals(column) && "between".equals(item.getOperator())) {
            condition = item;
            break;
        }
    }
    if (condition == null) {
        condition = new Condition();
        condition.setColumn(column);
        condition.setOperator("between");
        this.conditions.add(condition);
    }

    condition.getValues().add(begin);
    condition.getValues().add(end);
    return true;
}

/**
 * Converts one bound of a BETWEEN expression to a string: the literal value
 * for a character literal, the expression's {@code toString()} otherwise.
 */
private static String betweenBoundAsString(SQLExpr bound) {
    if (bound instanceof SQLCharExpr) {
        return (String) ((SQLCharExpr) bound).getValue();
    }
    return bound.toString();
}
/**
 * Resolves the column referenced by an expression.
 * <p>
 * Handles three expression shapes:
 * <ul>
 * <li>{@code owner.column} ({@link SQLPropertyExpr}) whose owner is a plain identifier;</li>
 * <li>a bare identifier, resolved against the current table;</li>
 * <li>a BETWEEN expression, resolved through its test expression.</li>
 * </ul>
 * Aliases are translated through the alias map. Names found in {@code variants}
 * are not treated as columns.
 *
 * @param expr the expression to resolve
 * @return the resolved column, or {@code null} when the alias map is unset or
 *         the expression cannot be mapped to a column
 */
@Override
protected Column getColumn(SQLExpr expr) {
    Map<String, String> aliasMap = getAliasMap();
    if (aliasMap == null) {
        return null;
    }

    if (expr instanceof SQLPropertyExpr) {
        SQLPropertyExpr propertyExpr = (SQLPropertyExpr) expr;
        SQLExpr owner = propertyExpr.getOwner();
        String column = propertyExpr.getName();

        if (owner instanceof SQLIdentifierExpr) {
            String tableName = ((SQLIdentifierExpr) owner).getName();
            String table = tableName;
            if (aliasMap.containsKey(table)) {
                table = aliasMap.get(table);
            }

            if (variants.containsKey(table)) {
                return null;
            }

            if (table != null) {
                return new Column(table, column);
            }

            // The alias maps to no table: try resolving it as a sub-query alias.
            return handleSubQueryColumn(tableName, column);
        }

        return null;
    }

    if (expr instanceof SQLIdentifierExpr) {
        Column attrColumn = (Column) expr.getAttribute(ATTR_COLUMN);
        if (attrColumn != null) {
            return attrColumn;
        }

        String column = ((SQLIdentifierExpr) expr).getName();
        String table = getCurrentTable();
        if (table != null && aliasMap.containsKey(table)) {
            table = aliasMap.get(table);
            if (table == null) {
                return null;
            }
        }

        if (table != null) {
            return new Column(table, column);
        }

        if (variants.containsKey(column)) {
            return null;
        }

        return new Column("UNKNOWN", column);
    }

    if (expr instanceof SQLBetweenExpr) {
        return resolveBetweenColumn((SQLBetweenExpr) expr, aliasMap);
    }
    return null;
}

/**
 * Resolves the column tested by a BETWEEN expression, or returns {@code null}
 * when it cannot be determined.
 */
private Column resolveBetweenColumn(SQLBetweenExpr betweenExpr, Map<String, String> aliasMap) {
    SQLExpr testExpr = betweenExpr.getTestExpr();
    if (testExpr == null) {
        return null;
    }

    String tableName = null;
    String column = null;
    if (testExpr instanceof SQLPropertyExpr) {
        // Column qualified with a table name or alias.
        SQLPropertyExpr propertyExpr = (SQLPropertyExpr) testExpr;
        SQLExpr owner = propertyExpr.getOwner();
        if (!(owner instanceof SQLIdentifierExpr)) {
            // e.g. schema.table.column: the owner is not a plain identifier. Casting it
            // blindly used to throw ClassCastException; treat it as unresolvable, the
            // same way the plain SQLPropertyExpr branch of getColumn does.
            return null;
        }
        tableName = ((SQLIdentifierExpr) owner).getName();
        column = propertyExpr.getName();
        SQLObject query = this.subQueryMap.get(tableName);
        if (query == null) {
            if (aliasMap.containsKey(tableName)) {
                tableName = aliasMap.get(tableName);
            }
            return new Column(tableName, column);
        }
        return handleSubQueryColumn(tableName, column);
    } else if (testExpr instanceof SQLIdentifierExpr) {
        column = ((SQLIdentifierExpr) testExpr).getName();
        // Unqualified column: with several tables this is easily ambiguous, since
        // it is unknown which table the column belongs to. fdbparser used a
        // defaultTable, i.e. the left table of the join.
        tableName = getOwnerTableName(betweenExpr, column);
    }

    String table = tableName;
    if (aliasMap.containsKey(table)) {
        table = aliasMap.get(table);
    }

    if (variants.containsKey(table)) {
        return null;
    }

    if (table != null && !"".equals(table)) {
        return new Column(table, column);
    }
    return null;
}
/**
 * Determines which table an unqualified column of a BETWEEN expression belongs to.
 * <p>
 * When the owner is ambiguous (the column could belong to several tables), real
 * queries must qualify the column with a table name or alias to avoid the ambiguity.
 *
 * @param betweenExpr the BETWEEN expression whose test column is unqualified
 * @param column the unqualified column name
 * @return the owning table name, or an empty string when it cannot be determined
 */
private String getOwnerTableName(SQLBetweenExpr betweenExpr, String column) {
    if (tableStats.size() == 1) {
        // Exactly one table: it must be the owner.
        return tableStats.keySet().iterator().next().getName();
    }
    if (tableStats.size() == 0) {
        // No table at all.
        return "";
    }

    // Several tables: first look for an already collected column with that name.
    for (Column col : columns) {
        if (col.getName().equals(column)) {
            return col.getTable();
        }
    }

    // Not found: derive the table from the enclosing statement.
    SQLObject parent = betweenExpr.getParent();
    if (parent instanceof SQLBinaryOpExpr) {
        // Skips one enclosing binary expression only.
        // NOTE(review): deeper AND/OR nesting ends in the "" fallback - confirm this is intended.
        parent = parent.getParent();
    }

    if (parent instanceof MySqlSelectQueryBlock) {
        MySqlSelectQueryBlock select = (MySqlSelectQueryBlock) parent;
        if (select.getFrom() instanceof SQLJoinTableSource) {
            // Join: the left side is taken as the main table. This is not rigorous;
            // qualify the column with its table name or alias for an exact result.
            SQLJoinTableSource joinTableSource = (SQLJoinTableSource) select.getFrom();
            return joinTableSource.getLeft().toString();
        } else if (select.getFrom() instanceof SQLExprTableSource) {
            // Single table.
            return select.getFrom().toString();
        }
    } else if (parent instanceof SQLUpdateStatement) {
        // getTableName() may yield no name (e.g. a non-simple table source);
        // fall back to "" instead of dereferencing null.
        SQLName updateTable = ((SQLUpdateStatement) parent).getTableName();
        return updateTable == null ? "" : updateTable.getSimpleName();
    } else if (parent instanceof SQLDeleteStatement) {
        SQLName deleteTable = ((SQLDeleteStatement) parent).getTableName();
        return deleteTable == null ? "" : deleteTable.getSimpleName();
    }
    return "";
}
/**
 * Visits a binary expression.
 * <p>
 * Both operands get this expression as their parent. For {@code =}, {@code <=>},
 * {@code IS} and {@code IS NOT} the operands are recorded as conditions (in both
 * directions) and as a relationship. For {@code OR} a where-unit is collected
 * unless the expression is always true, and the children are not visited.
 *
 * @param x the binary expression being visited
 * @return {@code false} for an OR expression, {@code true} otherwise
 */
@Override
public boolean visit(SQLBinaryOpExpr x) {
    SQLExpr left = x.getLeft();
    SQLExpr right = x.getRight();
    left.setParent(x);
    right.setParent(x);

    SQLBinaryOperator operator = x.getOperator();
    switch (operator) {
        case BooleanOr:
            collectOrWhereUnit(x);
            return false;
        case Equality:
        case LessThanOrEqualOrGreaterThan:
        case Is:
        case IsNot:
            handleCondition(left, operator.name, right);
            handleCondition(right, operator.name, left);
            handleRelationship(left, operator.name, right);
            return true;
        default:
            // Like, NotLike, NotEqual, range comparisons and everything else: nothing to record.
            return true;
    }
}

/**
 * Collects a where-unit for an OR expression, unless the expression is always
 * true (in which case the where condition is discarded).
 */
private void collectOrWhereUnit(SQLBinaryOpExpr orExpr) {
    if (RouterUtil.isConditionAlwaysTrue(orExpr)) {
        return;
    }
    hasOrCondition = true;

    WhereUnit unit;
    if (conditions.size() > 0) {
        // Conditions were already collected: keep them as outer conditions of a
        // finished wrapper unit and hang the OR expression below it.
        unit = new WhereUnit();
        unit.setFinishedParse(true);
        unit.addOutConditions(getConditions());
        unit.addSubWhereUnit(new WhereUnit(orExpr));
    } else {
        unit = new WhereUnit(orExpr);
        unit.addOutConditions(getConditions());
    }
    whereUnits.add(unit);
}
/**
 * Decomposes the collected where-units into condition groups.
 *
 * @return the merged condition lists, one inner list per combination
 */
public List<List<Condition>> splitConditions() {
    // Step 1: split every collected unit on OR.
    for (WhereUnit unit : whereUnits) {
        splitUntilNoOr(unit);
    }

    // Step 2: remember the top-level units, then look for nested ORs.
    storedwhereUnits.addAll(whereUnits);
    loopFindSubWhereUnit(whereUnits);

    // Step 3: turn the split expression blocks into Condition lists.
    for (WhereUnit stored : storedwhereUnits) {
        getConditionsFromWhereUnit(stored);
    }

    // Step 4: combine the units (combination across several levels of collections).
    List<List<Condition>> merged = mergedConditions();
    return merged;
}
/**
* Repeatedly looks for sub where-units (in practice: nested ORs).
* <p>
* Every split expression of every unit is visited with this visitor. When the
* visit flags an OR, the first where-unit collected by that visit is split and
* attached to the unit as a sub where-unit, and the expression is removed from
* the unit's split list. The method then recurses on the sub where-units found.
*
* @param whereUnitList the where-units to inspect; splitConditions() passes the
*        whereUnits field itself, which reset() clears while this loop runs
*        (the field is declared as a CopyOnWriteArrayList)
*/
private void loopFindSubWhereUnit(List<WhereUnit> whereUnitList) {
// Sub where-units to inspect in the next recursion level.
List<WhereUnit> subWhereUnits = new ArrayList<WhereUnit>();
for(WhereUnit whereUnit : whereUnitList) {
if(whereUnit.getSplitedExprList().size() > 0) {
// Expressions that contain an OR; removed after the loop rather than while iterating.
List<SQLExpr> removeSplitedList = new ArrayList<SQLExpr>();
for(SQLExpr sqlExpr : whereUnit.getSplitedExprList()) {
// Start from a clean state so the flag and this.whereUnits reflect only this expression.
reset();
if(isExprHasOr(sqlExpr)) {
removeSplitedList.add(sqlExpr);
// Only the first unit collected by the visit is used.
WhereUnit subWhereUnit = this.whereUnits.get(0);
splitUntilNoOr(subWhereUnit);
whereUnit.addSubWhereUnit(subWhereUnit);
subWhereUnits.add(subWhereUnit);
} else {
// No OR: drop the conditions the visit collected.
this.conditions.clear();
}
}
if(removeSplitedList.size() > 0) {
whereUnit.getSplitedExprList().removeAll(removeSplitedList);
}
}
// NOTE(review): if getSubWhereUnit() exposes the list addSubWhereUnit() appends to,
// the units added above are queued a second time here - confirm.
subWhereUnits.addAll(whereUnit.getSubWhereUnit());
}
if(subWhereUnits.size() > 0) {
loopFindSubWhereUnit(subWhereUnits);
}
}
/**
 * Visits the expression with this visitor and reports the OR flag afterwards.
 * The flag is not reset here; callers reset the state beforehand.
 *
 * @param expr the expression to visit
 * @return the value of the OR flag after the visit
 */
private boolean isExprHasOr(SQLExpr expr) {
    expr.accept(this);
    return this.hasOrCondition;
}
/**
 * Merges every stored where-unit internally, then combines all stored units.
 *
 * @return the combined condition lists; a new empty list when nothing is stored
 */
private List<List<Condition>> mergedConditions() {
    if (storedwhereUnits.isEmpty()) {
        return new ArrayList<List<Condition>>();
    }
    for (WhereUnit stored : storedwhereUnits) {
        mergeOneWhereUnit(stored);
    }
    return getMergedConditionList(storedwhereUnits);
}
/**
 * Recursively merges the condition lists inside one where-unit.
 * <p>
 * Sub where-units are merged first. With several sub units their condition
 * lists are combined, the unit's outer conditions are appended to every
 * combination, and the result replaces the unit's condition list. With a
 * single sub unit the outer conditions are appended to each of that sub
 * unit's lists, which are then added to the unit's own condition list.
 * A unit without sub units is left unchanged.
 *
 * @param whereUnit the where-unit to merge
 */
private void mergeOneWhereUnit(WhereUnit whereUnit) {
    List<WhereUnit> children = whereUnit.getSubWhereUnit();
    if (children.isEmpty()) {
        return;
    }

    for (WhereUnit child : children) {
        mergeOneWhereUnit(child);
    }

    boolean hasOuterConditions = whereUnit.getOutConditions().size() > 0;

    if (children.size() > 1) {
        List<List<Condition>> combined = getMergedConditionList(children);
        if (hasOuterConditions) {
            for (List<Condition> group : combined) {
                group.addAll(whereUnit.getOutConditions());
            }
        }
        whereUnit.setConditionList(combined);
        return;
    }

    // Exactly one sub where-unit.
    List<List<Condition>> onlyChildGroups = children.get(0).getConditionList();
    if (hasOuterConditions && onlyChildGroups.size() > 0) {
        for (List<Condition> group : onlyChildGroups) {
            group.addAll(whereUnit.getOutConditions());
        }
    }
    whereUnit.getConditionList().addAll(onlyChildGroups);
}
/**
 * Combines the condition lists of several where-units.
 *
 * @param whereUnitList the where-units whose condition lists are combined
 * @return the combined condition lists; empty when {@code whereUnitList} is empty
 */
private List<List<Condition>> getMergedConditionList(List<WhereUnit> whereUnitList) {
    List<List<Condition>> merged = new ArrayList<List<Condition>>();
    boolean first = true;
    for (WhereUnit unit : whereUnitList) {
        if (first) {
            // Seed the result with the first unit's lists.
            merged.addAll(unit.getConditionList());
            first = false;
        } else {
            merged = merge(merged, unit.getConditionList());
        }
    }
    return merged;
}
/**
 * Combines the conditions of two lists: every inner list of {@code list1} is
 * concatenated with every inner list of {@code list2}.
 *
 * @param list1 the first list of condition groups
 * @param list2 the second list of condition groups
 * @return {@code list2} itself when {@code list1} is empty, {@code list1} itself
 *         when {@code list2} is empty, otherwise a new list of new combined groups
 */
private List<List<Condition>> merge(List<List<Condition>> list1, List<List<Condition>> list2) {
    if (list1.isEmpty()) {
        return list2;
    }
    if (list2.isEmpty()) {
        return list1;
    }

    List<List<Condition>> product = new ArrayList<List<Condition>>();
    for (List<Condition> left : list1) {
        for (List<Condition> right : list2) {
            List<Condition> combined = new ArrayList<Condition>(left);
            combined.addAll(right);
            product.add(combined);
        }
    }
    return product;
}
/**
 * Builds the condition lists of a where-unit from its split expressions,
 * then does the same for each of its sub where-units.
 * <p>
 * The conditions already collected when the method is entered are treated as
 * conditions outside the OR (e.g. in {@code where c1 and (c2 or c3)}, c1 is an
 * outer condition because it was extracted earlier) and are appended to the
 * list of every split expression.
 *
 * @param whereUnit the where-unit to process
 */
private void getConditionsFromWhereUnit(WhereUnit whereUnit) {
    // Snapshot the outer conditions, then start from an empty collection.
    List<Condition> outerConditions = new ArrayList<Condition>(this.conditions);
    this.conditions.clear();

    List<List<Condition>> groups = new ArrayList<List<Condition>>();
    for (SQLExpr branch : whereUnit.getSplitedExprList()) {
        branch.accept(this);
        List<Condition> branchConditions = new ArrayList<Condition>(getConditions());
        branchConditions.addAll(outerConditions);
        groups.add(branchConditions);
        this.conditions.clear();
    }
    whereUnit.setConditionList(groups);

    for (WhereUnit child : whereUnit.getSubWhereUnit()) {
        getConditionsFromWhereUnit(child);
    }
}
/**
 * Splits a where-unit on OR until no OR is left.
 * <p>
 * For an unfinished unit, the right operand of each OR is added as a split
 * expression and the left operand becomes the next expression to split. For a
 * finished unit, its sub where-units are split instead.
 *
 * TODO: check whether nested ORs, sub-queries in conditions, EXISTS and other
 * complex cases are handled.
 *
 * @param whereUnit the where-unit to split
 */
private void splitUntilNoOr(WhereUnit whereUnit) {
    while (true) {
        if (whereUnit.isFinishedParse()) {
            for (WhereUnit child : whereUnit.getSubWhereUnit()) {
                splitUntilNoOr(child);
            }
            return;
        }

        SQLBinaryOpExpr current = whereUnit.getCanSplitExpr();
        if (current.getOperator() != SQLBinaryOperator.BooleanOr) {
            // Not an OR any more: this is the last piece.
            addExprIfNotFalse(whereUnit, current);
            whereUnit.setFinishedParse(true);
            return;
        }

        addExprIfNotFalse(whereUnit, current.getRight());
        SQLExpr left = current.getLeft();
        if (!(left instanceof SQLBinaryOpExpr)) {
            addExprIfNotFalse(whereUnit, left);
            return;
        }
        // Continue with the left operand.
        whereUnit.setCanSplitExpr((SQLBinaryOpExpr) left);
    }
}
/**
 * Adds the expression to the unit's split expressions unless it is always
 * false; only conditions that are not always false take part in routing.
 *
 * @param whereUnit the where-unit receiving the expression
 * @param expr the candidate expression
 */
private void addExprIfNotFalse(WhereUnit whereUnit, SQLExpr expr) {
    if (RouterUtil.isConditionAlwaysFalse(expr)) {
        return;
    }
    whereUnit.addSplitedExpr(expr);
}
/**
 * Visits an ALTER TABLE statement: increments the alter count of the table,
 * makes it the current table and visits every alter item.
 *
 * @param x the ALTER TABLE statement being visited
 * @return always {@code false}; the items are visited here
 */
@Override
public boolean visit(SQLAlterTableStatement x) {
    final String tableName = x.getName().toString();
    getTableStat(tableName, tableName).incrementAlterCount();
    setCurrentTable(x, tableName);

    for (SQLAlterTableItem alterItem : x.getItems()) {
        alterItem.setParent(x);
        alterItem.accept(this);
    }
    return false;
}
/**
 * Visits a MySQL DELETE statement.
 * <p>
 * Visits the FROM, USING and table source parts. When the table source is a
 * plain table expression, that table becomes the current table and its delete
 * count is incremented. WHERE, ORDER BY and LIMIT are visited afterwards.
 *
 * @param x the DELETE statement being visited
 * @return always {@code false}; the children are visited here
 */
public boolean visit(MySqlDeleteStatement x) {
    setAliasMap();
    setMode(x, Mode.Delete);

    accept(x.getFrom());
    accept(x.getUsing());
    x.getTableSource().accept(this);

    if (x.getTableSource() instanceof SQLExprTableSource) {
        SQLExprTableSource exprSource = (SQLExprTableSource) x.getTableSource();
        String ident = ((SQLName) exprSource.getExpr()).toString();
        setCurrentTable(x, ident);
        this.getTableStat(ident, ident).incrementDeleteCount();
    }

    accept(x.getWhere());
    accept(x.getOrderBy());
    accept(x.getLimit());
    return false;
}
/**
 * No-op: nothing is done when leaving a MySQL DELETE statement.
 *
 * @param x the DELETE statement being left
 */
public void endVisit(MySqlDeleteStatement x) {
    // nothing to do
}
}