/*
 * Copyright 1999-2015 dangdang.com.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * </p>
 */

package com.dangdang.ddframe.rdb.sharding.parser.visitor;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.visitor.SQLEvalVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitorUtils;
import com.alibaba.druid.util.JdbcUtils;
import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.constants.DatabaseType;
import com.dangdang.ddframe.rdb.sharding.parser.result.SQLParsedResult;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.AggregationColumn;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.AggregationColumn.AggregationType;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.GroupByColumn;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.OrderByColumn;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.OrderByColumn.OrderByType;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.Condition;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.Condition.BinaryOperator;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.Condition.Column;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.Table;
import com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql.MySQLEvalVisitor;
import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;
import com.google.common.base.Optional;
import com.google.common.base.Supplier;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * Context object of the SQL parsing process.
 *
 * <p>Collects the tables, sharding conditions, aggregation, order-by and group-by columns
 * discovered while visiting a statement and stores them in a {@link SQLParsedResult}.</p>
 *
 * @author zhangliang
 */
@Getter
public final class ParseContext {
    
    private static final String AUTO_GEN_TOKEN_KEY_TEMPLATE = "sharding_auto_gen_%d";
    
    private static final String SHARDING_GEN_ALIAS = "sharding_gen_%s";
    
    // Token key derived from the parse context index, see the constructor.
    private final String autoGenTokenKey;
    
    private final SQLParsedResult parsedResult = new SQLParsedResult();
    
    private final int parseContextIndex;
    
    @Setter
    private ShardingRule shardingRule;
    
    @Setter
    private boolean hasOrCondition;
    
    private final ConditionContext currentConditionContext = new ConditionContext();
    
    // Table used to resolve unqualified column names, see getColumnWithoutAlias.
    private Table currentTable;
    
    // Counter used only to number the generated derived-column aliases.
    private int selectItemsCount;
    
    private final Collection<String> selectItems = new HashSet<>();
    
    // Set once a "*" select item has been registered.
    private boolean hasAllColumn;
    
    @Setter
    private ParseContext parentParseContext;
    
    private List<ParseContext> subParseContext = new LinkedList<>();
    
    private int itemIndex;
    
    // Lazily filled cache of sharding columns per table; table and column names compare case-insensitively.
    private final Multimap<String, String> tableShardingColumnsMap = Multimaps.newSetMultimap(
            new TreeMap<String, Collection<String>>(String.CASE_INSENSITIVE_ORDER), new Supplier<Set<String>>() {
                
                @Override
                public Set<String> get() {
                    return new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
                }
            });
    
    public ParseContext(final int parseContextIndex) {
        this.parseContextIndex = parseContextIndex;
        autoGenTokenKey = String.format(AUTO_GEN_TOKEN_KEY_TEMPLATE, parseContextIndex);
    }
    
    /**
     * Increase the count of query projection items.
     */
    public void increaseItemIndex() {
        itemIndex++;
    }
    
    /**
     * Set the table that is currently being visited.
     *
     * <p>The table is also added to the route context of the parsed result.</p>
     *
     * @param currentTableName table name
     * @param currentAlias table alias
     */
    public void setCurrentTable(final String currentTableName, final Optional<String> currentAlias) {
        Table table = new Table(SQLUtil.getExactlyValue(currentTableName), 
                currentAlias.isPresent() ? Optional.of(SQLUtil.getExactlyValue(currentAlias.get())) : currentAlias);
        parsedResult.getRouteContext().getTables().add(table);
        currentTable = table;
    }
    
    /**
     * Add a table to the parse context.
     *
     * @param x table source expression, coming from FROM, INSERT, UPDATE, DELETE and similar statements
     * @return the table that was created and added to the route context
     */
    public Table addTable(final SQLExprTableSource x) {
        Table result = new Table(SQLUtil.getExactlyValue(x.getExpr().toString()), SQLUtil.getExactlyValue(x.getAlias()));
        parsedResult.getRouteContext().getTables().add(result);
        return result;
    }
    
    /**
     * Add a condition to the parse context.
     *
     * <p>Nothing is added when the expression cannot be resolved to a column, when the column is not
     * a sharding column, or when none of the value expressions can be evaluated.</p>
     *
     * @param expr SQL expression
     * @param operator operator
     * @param valueExprList value expressions
     * @param databaseType database type
     * @param parameters parameters passed in through placeholders
     */
    public void addCondition(final SQLExpr expr, final BinaryOperator operator, final List<SQLExpr> valueExprList, final DatabaseType databaseType, final List<Object> parameters) {
        Optional<Column> column = getColumn(expr);
        if (!column.isPresent()) {
            return;
        }
        if (notShardingColumns(column.get())) {
            return;
        }
        List<ValuePair> values = new ArrayList<>(valueExprList.size());
        for (SQLExpr each : valueExprList) {
            ValuePair evalValue = evalExpression(databaseType, each, parameters);
            if (null != evalValue) {
                values.add(evalValue);
            }
        }
        if (values.isEmpty()) {
            return;
        }
        addCondition(column.get(), operator, values);
    }
    
    /**
     * Add a condition to the parse context.
     *
     * <p>Nothing is added when the column is not a sharding column or the value expression cannot be evaluated.</p>
     *
     * @param columnName column name
     * @param tableName table name
     * @param operator operator
     * @param valueExpr value expression
     * @param databaseType database type
     * @param parameters parameters passed in through placeholders
     */
    public void addCondition(final String columnName, final String tableName, final BinaryOperator operator, final SQLExpr valueExpr, final DatabaseType databaseType, final List<Object> parameters) {
        Column column = createColumn(columnName, tableName);
        if (notShardingColumns(column)) {
            return;
        }
        ValuePair value = evalExpression(databaseType, valueExpr, parameters);
        if (null != value) {
            addCondition(column, operator, Collections.singletonList(value));
        }
    }
    
    /**
     * Append values to the condition for the given column and operator, creating the condition on first use.
     */
    private void addCondition(final Column column, final BinaryOperator operator, final List<ValuePair> valuePairs) {
        Optional<Condition> optionalCondition = currentConditionContext.find(column.getTableName(), column.getColumnName(), operator);
        Condition condition;
        // TODO to be discussed
        if (optionalCondition.isPresent()) {
            condition = optionalCondition.get();
        } else {
            condition = new Condition(column, operator);
            currentConditionContext.add(condition);
        }
        for (ValuePair each : valuePairs) {
            condition.getValues().add(each.value);
            // A parameter index of -1 marks a value that did not come from a placeholder.
            if (each.paramIndex > -1) {
                condition.getValueIndices().add(each.paramIndex);
            }
        }
    }
    
    /**
     * Tell whether the column is not a sharding column of its table.
     *
     * <p>The sharding columns of a table are fetched from the sharding rule the first time the table is seen.</p>
     */
    private boolean notShardingColumns(final Column column) {
        if (!tableShardingColumnsMap.containsKey(column.getTableName())) {
            tableShardingColumnsMap.putAll(column.getTableName(), shardingRule.getAllShardingColumns(column.getTableName()));
        }
        return !tableShardingColumnsMap.containsEntry(column.getTableName(), column.getColumnName());
    }
    
    /**
     * Evaluate a value expression.
     *
     * @return the evaluated value with its placeholder index (-1 when no index attribute is set),
     *         or {@code null} when the expression is a method invocation or evaluates to {@code null}
     */
    private ValuePair evalExpression(final DatabaseType databaseType, final SQLObject sqlObject, final List<Object> parameters) {
        if (sqlObject instanceof SQLMethodInvokeExpr) {
            // TODO sharding values inside function calls are not supported
            return null;
        }
        SQLEvalVisitor visitor;
        switch (databaseType.name().toLowerCase()) {
            case JdbcUtils.MYSQL:
            case JdbcUtils.H2:
                visitor = new MySQLEvalVisitor();
                break;
            default:
                visitor = SQLEvalVisitorUtils.createEvalVisitor(databaseType.name());
        }
        visitor.setParameters(parameters);
        sqlObject.accept(visitor);
        Object value = SQLEvalVisitorUtils.getValue(sqlObject);
        if (null == value) {
            // TODO NULL is currently parsed as an empty string; a proper solution is still to be considered here
            return null;
        }
        Comparable<?> finalValue;
        if (value instanceof Comparable<?>) {
            finalValue = (Comparable<?>) value;
        } else {
            // Values that are not comparable are replaced by an empty string.
            finalValue = "";
        }
        Integer index = (Integer) sqlObject.getAttribute(MySQLEvalVisitor.EVAL_VAR_INDEX);
        if (null == index) {
            index = -1;
        }
        return new ValuePair(finalValue, index);
    }
    
    /**
     * Resolve an expression to a column; only property expressions and identifier expressions are supported.
     */
    private Optional<Column> getColumn(final SQLExpr expr) {
        if (expr instanceof SQLPropertyExpr) {
            return Optional.fromNullable(getColumnWithQualifiedName((SQLPropertyExpr) expr));
        }
        if (expr instanceof SQLIdentifierExpr) {
            return Optional.fromNullable(getColumnWithoutAlias((SQLIdentifierExpr) expr));
        }
        return Optional.absent();
    }
    
    /**
     * Resolve a qualified column such as {@code owner.column}.
     *
     * @return the column, or {@code null} when the owner is not a plain identifier or matches no known table
     */
    private Column getColumnWithQualifiedName(final SQLPropertyExpr expr) {
        // Check the owner type before casting: the owner may be another kind of expression
        // (for example a nested property expression), which must yield "no column" rather than a ClassCastException.
        if (!(expr.getOwner() instanceof SQLIdentifierExpr)) {
            return null;
        }
        Optional<Table> table = findTable(((SQLIdentifierExpr) expr.getOwner()).getName());
        return table.isPresent() ? createColumn(expr.getName(), table.get().getName()) : null;
    }
    
    /**
     * Resolve an unqualified column against the current table.
     *
     * @return the column, or {@code null} when no current table is set
     */
    private Column getColumnWithoutAlias(final SQLIdentifierExpr expr) {
        return null != currentTable ? createColumn(expr.getName(), currentTable.getName()) : null;
    }
    
    private Column createColumn(final String columnName, final String tableName) {
        return new Column(SQLUtil.getExactlyValue(columnName), SQLUtil.getExactlyValue(tableName));
    }
    
    /**
     * Find a table by name first and by alias second.
     */
    private Optional<Table> findTable(final String tableNameOrAlias) {
        Optional<Table> tableFromName = findTableFromName(tableNameOrAlias);
        return tableFromName.isPresent() ? tableFromName : findTableFromAlias(tableNameOrAlias);
    }
    
    /**
     * Judge whether the SQL expression is part of a binary operation and is qualified by a table alias.
     *
     * @param x SQL expression to be judged
     * @param tableOrAliasName table name or alias
     * @return whether the expression is part of a binary operation and is qualified by a known alias
     */
    public boolean isBinaryOperateWithAlias(final SQLPropertyExpr x, final String tableOrAliasName) {
        return x.getParent() instanceof SQLBinaryOpExpr && findTableFromAlias(SQLUtil.getExactlyValue(tableOrAliasName)).isPresent();
    }
    
    private Optional<Table> findTableFromName(final String name) {
        for (Table each : parsedResult.getRouteContext().getTables()) {
            if (each.getName().equalsIgnoreCase(SQLUtil.getExactlyValue(name))) {
                return Optional.of(each);
            }
        }
        return Optional.absent();
    }
    
    private Optional<Table> findTableFromAlias(final String alias) {
        for (Table each : parsedResult.getRouteContext().getTables()) {
            if (each.getAlias().isPresent() && each.getAlias().get().equalsIgnoreCase(SQLUtil.getExactlyValue(alias))) {
                return Optional.of(each);
            }
        }
        return Optional.absent();
    }
    
    /**
     * Add the derived columns of an AVG column to the parse context.
     *
     * <p>A COUNT column and a SUM column are derived, in that order.</p>
     *
     * @param avgColumn the AVG column
     */
    public void addDerivedColumnsForAvgColumn(final AggregationColumn avgColumn) {
        addDerivedColumnForAvgColumn(avgColumn, getDerivedCountColumn(avgColumn));
        addDerivedColumnForAvgColumn(avgColumn, getDerivedSumColumn(avgColumn));
    }
    
    private void addDerivedColumnForAvgColumn(final AggregationColumn avgColumn, final AggregationColumn derivedColumn) {
        avgColumn.getDerivedColumns().add(derivedColumn);
        parsedResult.getMergeContext().getAggregationColumns().add(derivedColumn);
    }
    
    private AggregationColumn getDerivedCountColumn(final AggregationColumn avgColumn) {
        String expression = avgColumn.getExpression().replaceFirst(AggregationType.AVG.toString(), AggregationType.COUNT.toString());
        return new AggregationColumn(expression, AggregationType.COUNT, Optional.of(generateDerivedColumnAlias()), avgColumn.getOption());
    }
    
    private String generateDerivedColumnAlias() {
        return String.format(SHARDING_GEN_ALIAS, ++selectItemsCount);
    }
    
    private AggregationColumn getDerivedSumColumn(final AggregationColumn avgColumn) {
        String expression = avgColumn.getExpression().replaceFirst(AggregationType.AVG.toString(), AggregationType.SUM.toString());
        // The derived SUM column carries no option, so the option text is removed from its expression as well.
        if (avgColumn.getOption().isPresent()) {
            expression = expression.replaceFirst(avgColumn.getOption().get() + " ", "");
        }
        return new AggregationColumn(expression, AggregationType.SUM, Optional.of(generateDerivedColumnAlias()), Optional.<String>absent());
    }
    
    /**
     * Add an order-by column to the parse context.
     *
     * @param index column index
     * @param orderByType order-by type
     */
    public void addOrderByColumn(final int index, final OrderByType orderByType) {
        parsedResult.getMergeContext().getOrderByColumns().add(new OrderByColumn(index, orderByType));
    }
    
    /**
     * Add an order-by column to the parse context.
     *
     * @param owner column owner
     * @param name column name
     * @param orderByType order-by type
     */
    public void addOrderByColumn(final Optional<String> owner, final String name, final OrderByType orderByType) {
        String rawName = SQLUtil.getExactlyValue(name);
        parsedResult.getMergeContext().getOrderByColumns().add(new OrderByColumn(owner, rawName, getAlias(rawName), orderByType));
    }
    
    /**
     * Get the alias for a column: absent when the column is already covered by the select items,
     * otherwise a newly generated derived-column alias.
     */
    private Optional<String> getAlias(final String name) {
        if (containsSelectItem(name)) {
            return Optional.absent();
        }
        return Optional.of(generateDerivedColumnAlias());
    }
    
    private boolean containsSelectItem(final String selectItem) {
        return hasAllColumn || selectItems.contains(selectItem);
    }
    
    /**
     * Add a group-by column to the parse context.
     *
     * @param owner column owner
     * @param name column name
     * @param orderByType order-by type
     */
    public void addGroupByColumns(final Optional<String> owner, final String name, final OrderByType orderByType) {
        String rawName = SQLUtil.getExactlyValue(name);
        parsedResult.getMergeContext().getGroupByColumns().add(new GroupByColumn(owner, rawName, getAlias(rawName), orderByType));
    }
    
    /**
     * Merge the condition context that is currently being parsed into the parsed result.
     *
     * <p>When this context has tables of its own, or no sub context provides any, the current condition
     * context is registered if the parsed result has no condition context yet. Otherwise the tables and
     * condition contexts of the first sub context that has tables are copied into this parsed result.</p>
     */
    public void mergeCurrentConditionContext() {
        if (!parsedResult.getRouteContext().getTables().isEmpty()) {
            addCurrentConditionContextIfAbsent();
            return;
        }
        Optional<SQLParsedResult> target = findValidParseResult();
        if (!target.isPresent()) {
            addCurrentConditionContextIfAbsent();
            return;
        }
        parsedResult.getRouteContext().getTables().addAll(target.get().getRouteContext().getTables());
        parsedResult.getConditionContexts().addAll(target.get().getConditionContexts());
    }
    
    /**
     * Register the current condition context when the parsed result has no condition context yet.
     */
    private void addCurrentConditionContextIfAbsent() {
        if (parsedResult.getConditionContexts().isEmpty()) {
            parsedResult.getConditionContexts().add(currentConditionContext);
        }
    }
    
    /**
     * Merge the sub contexts one by one and return the parsed result of the first one that has tables.
     */
    private Optional<SQLParsedResult> findValidParseResult() {
        for (ParseContext each : subParseContext) {
            each.mergeCurrentConditionContext();
            if (each.getParsedResult().getRouteContext().getTables().isEmpty()) {
                continue;
            }
            return Optional.of(each.getParsedResult()); 
        }
        return Optional.absent();
    }
    
    /**
     * Register a column name or alias declared in the SELECT clause.
     *
     * <p>A {@code *} item is not stored; it marks the context as selecting all columns instead.</p>
     *
     * @param selectItem column name or alias declared in the SELECT clause
     */
    public void registerSelectItem(final String selectItem) {
        String rawItemExpr = SQLUtil.getExactlyValue(selectItem);
        if ("*".equals(rawItemExpr)) {
            hasAllColumn = true;
            return;
        }
        selectItems.add(rawItemExpr);
    }
    
    /**
     * An evaluated value together with the index of the placeholder it came from (-1 when there is none).
     */
    @RequiredArgsConstructor
    private static class ValuePair {
        
        private final Comparable<?> value;
        
        private final Integer paramIndex;
    }
}