package jef.database.jsqlparser;
import java.net.URL;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import javax.persistence.PersistenceException;
import jef.common.log.LogUtil;
import jef.database.DbMetaData;
import jef.database.ORMConfig;
import jef.database.OperateTarget;
import jef.database.dialect.DatabaseDialect;
import jef.database.jdbc.JDBCTarget;
import jef.database.jsqlparser.expression.BinaryExpression;
import jef.database.jsqlparser.expression.Column;
import jef.database.jsqlparser.expression.Function;
import jef.database.jsqlparser.expression.Interval;
import jef.database.jsqlparser.expression.operators.arithmetic.Addition;
import jef.database.jsqlparser.expression.operators.arithmetic.Concat;
import jef.database.jsqlparser.expression.operators.relational.ExpressionList;
import jef.database.jsqlparser.statement.select.Limit;
import jef.database.jsqlparser.statement.select.StartWithExpression;
import jef.database.jsqlparser.visitor.Expression;
import jef.database.jsqlparser.visitor.VisitorAdapter;
import jef.database.meta.DbProperty;
import jef.database.meta.Feature;
import jef.database.meta.FunctionMapping;
import jef.database.query.function.SQLFunction;
/**
 * AST visitor that rewrites functions, string concatenation and similar
 * constructs into the form required by the current database dialect.
 *
 * @author jiyi
 */
public class SqlFunctionlocalization extends VisitorAdapter {
    /** Dialect whose features, keywords and function mappings drive the rewrite. */
    private final DatabaseDialect profile;
    /** Connection target used to look up user-defined functions; may be null. */
    private final JDBCTarget db;
    /** Whether functions unknown to the dialect are verified against the database. */
    private final boolean check;
    /**
     * Copy of a START WITH / CONNECT BY clause that was stripped from the SQL
     * because the dialect does not support it. It is kept so that the recursive
     * filtering can be performed in memory afterwards. Null when nothing was
     * stripped.
     */
    public StartWithExpression delayStartWith;
    /**
     * Copy of a LIMIT clause that was cleared from the SQL because the dialect
     * does not support it. Null when nothing was cleared.
     */
    public Limit delayLimit;

    /**
     * Creates the visitor.
     *
     * @param dialect
     *            the database dialect profile
     * @param db
     *            used to verify user-defined functions; when it is null and
     *            function checking is enabled, any function unknown to the
     *            dialect is rejected with an IllegalArgumentException
     */
    public SqlFunctionlocalization(DatabaseDialect dialect, JDBCTarget db) {
        this.profile = dialect;
        this.db = db;
        this.check = ORMConfig.getInstance().isCheckSqlFunctions();
    }

    /**
     * Rewrites a concatenation: into an addition when the dialect concatenates
     * with '+', or into a concat(...) function call when the dialect has no
     * concatenation operator.
     */
    @Override
    public void visit(Concat concat) {
        super.visit(concat);// handle the nested expressions first
        if (profile.has(Feature.CONCAT_IS_ADD)) {
            concat.rewrite = new Addition(concat.getLeftExpression(), concat.getRightExpression());
        } else if (profile.notHas(Feature.SUPPORT_CONCAT)) {
            List<Expression> operands = new ArrayList<Expression>();
            recursion(concat, operands);
            Function func = new Function();
            func.setName("concat");
            func.setParameters(new ExpressionList(operands));
            concat.rewrite = func;
        }
    }

    /**
     * Added by Jiyi on 2014-10-22. When a column whose name is a keyword is
     * written without quotes, the statement may fail on databases that do not
     * allow that keyword as an identifier. Such column names are therefore
     * wrapped with the dialect's quote characters.
     *
     * TODO this may have unexpected side effects: if the parser mistakes a
     * function without parentheses for a column (for example CURRENT_TIMESTAMP),
     * it would be quoted by mistake. This has not been observed so far but needs
     * further testing.
     */
    @Override
    public void visit(Column tableColumn) {
        String wrapper = profile.getProperty(DbProperty.WRAP_FOR_KEYWORD);
        if (wrapper == null || wrapper.length() == 0) {
            return; // no quote characters configured for this dialect
        }
        String columnName = tableColumn.getColumnName();
        if (!profile.containKeyword(columnName)) {
            return;
        }
        if (visitPath.size() > 0) {
            Object parent = visitPath.getFirst();
            // A list that is not comma separated is something like
            // cast(xx as int); the "int" there must not be quoted.
            if (parent instanceof ExpressionList && !",".equals(((ExpressionList) parent).getBetween())) {
                return;
            }
        }
        char open = wrapper.charAt(0);
        // A single-character wrapper is used on both sides.
        char close = wrapper.length() > 1 ? wrapper.charAt(1) : open;
        StringBuilder sb = new StringBuilder(columnName.length() + 2);
        tableColumn.setColumnName(sb.append(open).append(columnName).append(close).toString());
    }

    /**
     * Rewrites a function call using the dialect's mapping for it. A function
     * without any mapping is left untouched; when function checking is enabled
     * it must additionally exist in the database as a user-defined function.
     *
     * @throws IllegalArgumentException
     *             if checking is enabled and the function is unknown to both the
     *             dialect and the database
     * @throws RuntimeException
     *             wrapping the SQLException raised while querying the database
     */
    @Override
    public void visit(Function function) {
        super.visit(function);// handle the nested expressions first
        // Locale.ROOT keeps the lookup key stable under locales with special
        // casing rules (e.g. Turkish dotless i).
        String funName = function.getName().toLowerCase(java.util.Locale.ROOT);
        FunctionMapping mapping = profile.getFunctions().get(funName);// known by name
        if (mapping == null) {
            mapping = findMappingByStandardName(funName);
        }
        if (mapping == null) {
            if (check) {
                // may be a function the user created in the database
                try {
                    checkUserFunction(funName);
                } catch (SQLException e) {
                    throw new RuntimeException(e);
                }
            }
            // Nothing to rewrite with: keep the call as the user wrote it.
            return;
        }
        mapping.rewrite(function);
    }

    /**
     * Looks up the mapping registered for the standard function enum constant
     * with the given name.
     *
     * @return the mapping, or null when the name is not a standard function or
     *         the dialect has no mapping for it
     */
    private FunctionMapping findMappingByStandardName(String funName) {
        jef.database.query.Func func;
        try {
            func = jef.database.query.Func.valueOf(funName);
        } catch (IllegalArgumentException ignored) {
            // not one of the standard functions
            return null;
        }
        return profile.getFunctionsByEnum().get(func);
    }

    /**
     * Verifies that the given function exists in the database, caching positive
     * results in the metadata's checked-function set. Returns silently when no
     * metadata is available.
     *
     * @throws IllegalArgumentException
     *             if there is no database to check against or the function does
     *             not exist
     */
    private void checkUserFunction(String funName) throws SQLException {
        if (db == null) {
            throw unsupportedFunction(funName);
        }
        DbMetaData meta = db.getMetaData();
        if (meta == null || meta.checkedFunctions.contains(funName)) {
            return;
        }
        if (!meta.existsFunction(null, funName)) {
            throw unsupportedFunction(funName);
        }
        meta.checkedFunctions.add(funName);
    }

    /** Builds the exception reported for a function the database does not provide. */
    private IllegalArgumentException unsupportedFunction(String funName) {
        return new IllegalArgumentException("database " + profile.getName() + " doesn't support function: " + funName + ".");
    }

    /**
     * Handles START WITH / CONNECT BY on dialects that do not support it. Close
     * to the statement root the clause is saved in {@link #delayStartWith};
     * deeper in the tree it is either dropped with a warning or rejected,
     * depending on configuration. In both non-failing cases the clause is
     * removed from the SQL.
     *
     * @throws PersistenceException
     *             if a nested clause is found and removing it is not allowed
     */
    @Override
    public void visit(StartWithExpression startWithExpression) {
        if (profile.notHas(Feature.SUPPORT_CONNECT_BY)) {
            if (super.visitPath.size() <= 2) { // at most 2 levels away from the statement
                // keep the recursive condition so it can be applied in memory later
                delayStartWith = new StartWithExpression(startWithExpression.getStartExpression(), startWithExpression.getConnectExpression());
            } else if (ORMConfig.getInstance().isAllowRemoveStartWith()) {
                String removed = startWithExpression.toString();
                LogUtil.warn("[" + removed + "] was removed from your SQL since current db doesn't support it.");
            } else {
                throw new PersistenceException("The 'START WITH ... CONNECT BY ...' syntax, current db [" + profile.getName() + "] doesn't support!");
            }
            startWithExpression.setStartExpression(null);
            startWithExpression.setConnectExpression(null);
        }
        super.visit(startWithExpression);
    }

    /**
     * Handles LIMIT on dialects that do not support it: a clause close to the
     * statement root is copied into {@link #delayLimit} and cleared from the
     * SQL. Deeper clauses are left untouched.
     */
    @Override
    public void visit(Limit limit) {
        if (profile.notHas(Feature.SUPPORT_LIMIT) && super.visitPath.size() <= 2) {
            // at most 2 levels away from the statement: keep a copy of the limit
            delayLimit = new Limit(limit);
            limit.clear();
        }
        super.visit(limit);
    }

    /**
     * Makes sure the user functions required by a mapping exist in the database,
     * running the SQL script shipped next to the function class
     * ({@code <SimpleClassName>.sql}) when one of them is missing.
     *
     * @param mapping
     *            the mapping whose required user functions are verified
     * @param db
     *            the target whose metadata is queried and used to run the script
     * @throws SQLException
     *             if querying the database or executing the script fails
     * @throws IllegalArgumentException
     *             if a function is missing and no script resource is found
     */
    public static void ensureUserFunction(FunctionMapping mapping, OperateTarget db) throws SQLException {
        DbMetaData meta = db.getMetaData();
        String missing = null;
        for (String name : mapping.requiresUserFunction()) {
            if (meta.checkedFunctions.contains(name)) {
                continue;
            }
            if (meta.existsFunction(null, name)) {
                meta.checkedFunctions.add(name);
            } else {
                missing = name;
                break;
            }
        }
        if (missing == null) {
            return;
        }
        SQLFunction sf = mapping.getFunction();
        URL url = sf.getClass().getResource(sf.getClass().getSimpleName() + ".sql");
        if (url == null) {
            throw new IllegalArgumentException("Can't find user script file for user function " + sf);
        }
        meta.executeScriptFile(url);
        // Only remember the function once the script ran successfully, so that
        // a failed installation is retried on the next call.
        meta.checkedFunctions.add(missing);
    }

    /**
     * Flattens a left-nested chain of concatenations into a list of operands,
     * in left-to-right order.
     */
    private void recursion(Concat concat, List<Expression> el) {
        Expression left = concat.getLeftExpression();
        if (left instanceof Concat) {
            recursion((Concat) left, el);
        } else {
            el.add(left);
        }
        el.add(concat.getRightExpression());
    }

    /**
     * Only PostgreSQL and MySQL support the INTERVAL syntax, and even their
     * syntaxes differ. Oracle does not support INTERVAL, but supports fractional
     * arithmetic where 1 stands for one day, e.g. 1/86400 = 1 second, 1/1440 = 1
     * minute, 1/24 = 1 hour. Oracle can add or subtract days with integer
     * arithmetic, months with add_months, and years with add_months using
     * 12 / -12.
     */
    @Override
    public void visit(Interval interval) {
        super.visit(interval);
        // Temporarily remove the direct parent so the grandparent becomes
        // reachable via getFirst(); always restore it, even on failure.
        Object parent = visitPath.pop();
        try {
            if (parent instanceof BinaryExpression) {
                profile.processIntervalExpression((BinaryExpression) parent, interval);
            } else if (parent instanceof ExpressionList) {
                Object func = visitPath.getFirst();
                if (func instanceof Function) {
                    profile.processIntervalExpression((Function) func, interval);
                }
            }
        } finally {
            visitPath.push(parent);
        }
    }
}