/*
* Copyright 1999-2017 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.druid.sql.oracle.demo;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLNumericLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.alibaba.druid.sql.parser.SQLStatementParser;
/**
 * Demo of table-name rewriting ("sharding") with the Druid SQL AST.
 *
 * <p>Each test parses a MySQL {@code SELECT}, looks for an equality comparison against the
 * {@code uid} column of the {@code user} table, and renames every {@code user} table source to
 * {@code user_1} when that value is {@code 1} or to {@code user_x} otherwise. The rewritten
 * statement is printed; the tests make no assertions.
 */
public class Demo3 extends TestCase {

    public void test_0() throws Exception {
        String sql = "select * from user u where u.uid = 2 and uname = ?";
        String realSql = convert(sql, sampleParameters());
        System.out.println(realSql);
    }

    public void test_1() throws Exception {
        String sql = "select * from user where uid = ? and uname = ?";
        String realSql = convert(sql, sampleParameters());
        System.out.println(realSql);
    }

    public void test_2() throws Exception {
        String sql = "select * from (select * from user where uid = ? and uname = ?) t";
        String realSql = convert(sql, sampleParameters());
        System.out.println(realSql);
    }

    public void test_3() throws Exception {
        String sql = "select * from groups where uid = ? and uname = ?";
        String realSql = convert(sql, sampleParameters());
        System.out.println(realSql);
    }

    /** Returns the bind values shared by every test: {@code 1} then {@code "wenshao"}. */
    private static List<Object> sampleParameters() {
        List<Object> parameters = new ArrayList<Object>();
        parameters.add(1);
        parameters.add("wenshao");
        return parameters;
    }

    /**
     * Parses {@code sql}, chooses the shard table from the first {@code uid = ...} value found,
     * renames every {@code user} table source to that shard, and renders the result.
     *
     * @param sql        a MySQL statement; only the first statement is rewritten
     * @param parameters positional bind values for the {@code ?} placeholders
     * @return the rewritten first statement, formatted by {@link SQLUtils#toOracleString}
     * @throws IllegalArgumentException if {@code sql} contains no statement, or the bind value
     *                                  for the shard key is missing or not numeric
     */
    private String convert(String sql, List<Object> parameters) {
        SQLStatementParser parser = new MySqlStatementParser(sql);
        List<SQLStatement> stmtList = parser.parseStatementList();
        if (stmtList.isEmpty()) {
            throw new IllegalArgumentException("no statement found in sql: " + sql);
        }
        SQLStatement first = stmtList.get(0);

        MyVisitor visitor = new MyVisitor();
        first.accept(visitor);

        List<SQLExpr> shardKeys = visitor.getVariantList();
        if (!shardKeys.isEmpty()) {
            int userId = resolveUserId(shardKeys.get(0), parameters);
            // Only uid == 1 has a dedicated table; every other value shares user_x.
            String tableName = userId == 1 ? "user_1" : "user_x";
            renameUserTables(visitor.getTableSourceList(), tableName);
        }

        // The statement is parsed with the MySQL parser but rendered with the Oracle formatter.
        return SQLUtils.toOracleString(first);
    }

    /**
     * Resolves the integer value of a shard-key expression.
     *
     * @param shardKeyExpr either a {@code ?} placeholder (tagged with a {@code varIndex}
     *                     attribute by {@link MyVisitor}) or a numeric literal
     * @param parameters   positional bind values for the placeholders
     * @return the shard key as an {@code int}
     * @throws IllegalArgumentException if the bind value is missing or not a {@link Number}
     */
    private static int resolveUserId(SQLExpr shardKeyExpr, List<Object> parameters) {
        if (shardKeyExpr instanceof SQLVariantRefExpr) {
            int varIndex = (Integer) shardKeyExpr.getAttribute("varIndex");
            if (varIndex >= parameters.size()) {
                throw new IllegalArgumentException(
                        "missing bind value for placeholder #" + varIndex
                                + " (only " + parameters.size() + " supplied)");
            }
            Object value = parameters.get(varIndex);
            // Accept any numeric box (Integer, Long, ...) rather than only Integer.
            if (!(value instanceof Number)) {
                throw new IllegalArgumentException(
                        "shard key bind value #" + varIndex + " is not numeric: " + value);
            }
            return ((Number) value).intValue();
        }
        return ((SQLNumericLiteralExpr) shardKeyExpr).getNumber().intValue();
    }

    /**
     * Renames every collected {@code user} table source to {@code tableName}, mutating the AST.
     *
     * @param tableSources table sources collected by {@link MyVisitor}
     * @param tableName    replacement table name
     */
    private static void renameUserTables(List<SQLExprTableSource> tableSources, String tableName) {
        for (SQLExprTableSource tableSource : tableSources) {
            SQLExpr expr = tableSource.getExpr();
            if (expr instanceof SQLIdentifierExpr) {
                SQLIdentifierExpr identExpr = (SQLIdentifierExpr) expr;
                if ("user".equals(identExpr.getName())) {
                    identExpr.setName(tableName);
                }
            } else if (expr instanceof SQLPropertyExpr) {
                // Qualified name such as schema.user: only the last segment is replaced.
                // MyVisitor currently collects only SQLIdentifierExpr sources, so this branch
                // is not reached today.
                SQLPropertyExpr propExpr = (SQLPropertyExpr) expr;
                if ("user".equals(propExpr.getName())) {
                    propExpr.setName(tableName);
                }
            }
        }
    }

    /**
     * Collects shard-key expressions ({@code uid = ...} against the {@code user} table) and the
     * {@code user} table sources of a statement, and numbers the {@code ?} placeholders.
     *
     * <p>NOTE(review): {@code defaultTableName} is overwritten each time a query block is entered
     * and is never restored. After a nested subquery has been visited, an unqualified {@code uid}
     * in the enclosing block is therefore checked against the subquery's table. Confirm whether
     * this matters for the intended queries.
     */
    private static class MyVisitor extends MySqlASTVisitorAdapter {
        /**
         * Next positional index for a {@code ?} placeholder. Indexes are assigned in visit order,
         * which is assumed to match the textual order of the placeholders.
         */
        private int varIndex = 0;
        /** Right-hand sides of {@code uid = ...} comparisons against the user table, in visit order. */
        private final List<SQLExpr> variantList = new ArrayList<SQLExpr>();
        /** Distinct {@code user} table sources, in the order they were first recorded. */
        private final List<SQLExprTableSource> tableSourceList = new ArrayList<SQLExprTableSource>();
        /** Maps a table alias to the underlying (unqualified) table name. */
        private final Map<String, String> tableAlias = new HashMap<String, String>();
        /**
         * Table name of the FROM clause of the most recently entered query block, when that FROM
         * is a single table expression; used to resolve unqualified column references.
         */
        private String defaultTableName;

        public boolean visit(SQLVariantRefExpr x) {
            x.getAttributes().put("varIndex", varIndex++);
            return true;
        }

        public boolean visit(SQLBinaryOpExpr x) {
            // Only an equality against user.uid pins the shard; ranges, <> etc. must not.
            if (x.getOperator() == SQLBinaryOperator.Equality && isUserId(x.getLeft())) {
                SQLExpr right = x.getRight();
                // isUserId already guarantees the column is "uid" for both the plain and the
                // qualified (alias.uid) forms, so no cast of the left side is needed here.
                if (right instanceof SQLVariantRefExpr || right instanceof SQLNumericLiteralExpr) {
                    variantList.add(right);
                }
            }
            return true;
        }

        /**
         * Returns whether {@code x} refers to the {@code uid} column of the {@code user} table,
         * either unqualified (when the current FROM table is {@code user}) or qualified by the
         * table name or one of its aliases.
         */
        private boolean isUserId(SQLExpr x) {
            if (x instanceof SQLIdentifierExpr) {
                return "user".equals(defaultTableName)
                        && "uid".equals(((SQLIdentifierExpr) x).getName());
            }
            if (x instanceof SQLPropertyExpr) {
                SQLPropertyExpr propExpr = (SQLPropertyExpr) x;
                if (!"uid".equals(propExpr.getName())) {
                    return false;
                }
                if (propExpr.getOwner() instanceof SQLIdentifierExpr) {
                    String ownerName = ((SQLIdentifierExpr) propExpr.getOwner()).getName();
                    return "user".equals(ownerName) || "user".equals(tableAlias.get(ownerName));
                }
            }
            return false;
        }

        public boolean visit(SQLExprTableSource x) {
            recordTableSource(x);
            return true;
        }

        /**
         * Records the alias of a plain-identifier table source and remembers it if it is the
         * {@code user} table.
         *
         * @return the table name, or {@code null} if the source is not a plain identifier
         */
        private String recordTableSource(SQLExprTableSource x) {
            if (!(x.getExpr() instanceof SQLIdentifierExpr)) {
                return null;
            }
            String tableName = ((SQLIdentifierExpr) x.getExpr()).getName();
            if (x.getAlias() != null) {
                tableAlias.put(x.getAlias(), tableName);
            }
            // The FROM table can be recorded both on block entry and when visited directly.
            if ("user".equals(tableName) && !tableSourceList.contains(x)) {
                tableSourceList.add(x);
            }
            return tableName;
        }

        /** Updates {@code defaultTableName} from a query block's FROM clause when it is a single table. */
        private void enterQueryBlock(SQLSelectQueryBlock queryBlock) {
            if (queryBlock.getFrom() instanceof SQLExprTableSource) {
                defaultTableName = recordTableSource((SQLExprTableSource) queryBlock.getFrom());
            }
        }

        public boolean visit(SQLSelectQueryBlock queryBlock) {
            enterQueryBlock(queryBlock);
            return true;
        }

        public boolean visit(MySqlSelectQueryBlock queryBlock) {
            enterQueryBlock(queryBlock);
            return true;
        }

        /** Returns the collected shard-key expressions (live list, not a copy). */
        public List<SQLExpr> getVariantList() {
            return variantList;
        }

        /** Returns the collected {@code user} table sources (live list, not a copy). */
        public List<SQLExprTableSource> getTableSourceList() {
            return tableSourceList;
        }
    }
}