/* * Copyright 1999-2015 dangdang.com. * <p> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * </p> */ package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql; import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLCharExpr; import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; import com.alibaba.druid.sql.ast.expr.SQLNumberExpr; import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; import com.dangdang.ddframe.rdb.sharding.api.rule.TableRule; import com.dangdang.ddframe.rdb.sharding.parser.result.GeneratedKeyContext; import com.dangdang.ddframe.rdb.sharding.parser.result.router.Condition.BinaryOperator; import com.dangdang.ddframe.rdb.sharding.util.SQLUtil; import com.google.common.base.Optional; import java.util.Collection; import java.util.List; /** * MySQL的INSERT语句访问器. * * @author gaohongtao * @author zhangliang */ public class MySQLInsertVisitor extends AbstractMySQLVisitor { @Override public boolean visit(final MySqlInsertStatement x) { final String tableName = SQLUtil.getExactlyValue(x.getTableName().toString()); getParseContext().setCurrentTable(tableName, Optional.fromNullable(x.getAlias())); if (null == x.getValues()) { return super.visit(x); } Collection<String> autoIncrementColumns = getParseContext().getShardingRule().getAutoIncrementColumns(tableName); List<SQLExpr> columns = x.getColumns(); List<SQLExpr> values = x.getValues().getValues(); for (int i = 0; i < x.getColumns().size(); i++) { String columnName = SQLUtil.getExactlyValue(columns.get(i).toString()); getParseContext().addCondition(columnName, tableName, BinaryOperator.EQUAL, values.get(i), getDatabaseType(), getParameters()); if (autoIncrementColumns.contains(columnName)) { autoIncrementColumns.remove(columnName); } } if (autoIncrementColumns.isEmpty()) { return super.visit(x); } supplyAutoIncrementColumn(autoIncrementColumns, tableName, columns, values); return super.visit(x); } private void supplyAutoIncrementColumn(final Collection<String> autoIncrementColumns, final String tableName, final List<SQLExpr> columns, final List<SQLExpr> values) { boolean isPreparedStatement = !getParameters().isEmpty(); GeneratedKeyContext generatedKeyContext = getParseContext().getParsedResult().getGeneratedKeyContext(); if (isPreparedStatement) { generatedKeyContext.getColumns().addAll(autoIncrementColumns); } TableRule tableRule = getParseContext().getShardingRule().findTableRule(tableName); for (String each : autoIncrementColumns) { SQLExpr sqlExpr; Object id = tableRule.generateId(each); generatedKeyContext.putValue(each, id); if (isPreparedStatement) { sqlExpr = new SQLVariantRefExpr("?"); getParameters().add(id); ((SQLVariantRefExpr) sqlExpr).setIndex(getParametersSize() - 1); } else { sqlExpr = (id instanceof Number) ? new SQLNumberExpr((Number) id) : new SQLCharExpr((String) id); } getParseContext().addCondition(each, tableName, BinaryOperator.EQUAL, sqlExpr, getDatabaseType(), getParameters()); columns.add(new SQLIdentifierExpr(each)); values.add(sqlExpr); } } }