/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.skife.jdbi.v2;
import static org.skife.jdbi.rewriter.colon.ColonStatementLexer.DOUBLE_QUOTED_TEXT;
import static org.skife.jdbi.rewriter.colon.ColonStatementLexer.ESCAPED_TEXT;
import static org.skife.jdbi.rewriter.colon.ColonStatementLexer.LITERAL;
import static org.skife.jdbi.rewriter.colon.ColonStatementLexer.NAMED_PARAM;
import static org.skife.jdbi.rewriter.colon.ColonStatementLexer.POSITIONAL_PARAM;
import static org.skife.jdbi.rewriter.colon.ColonStatementLexer.QUOTED_TEXT;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.skife.jdbi.org.antlr.runtime.ANTLRStringStream;
import org.skife.jdbi.org.antlr.runtime.Token;
import org.skife.jdbi.rewriter.colon.ColonStatementLexer;
import org.skife.jdbi.v2.exceptions.UnableToCreateStatementException;
import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.skife.jdbi.v2.tweak.Argument;
import org.skife.jdbi.v2.tweak.RewrittenStatement;
import org.skife.jdbi.v2.tweak.StatementRewriter;
import com.google.common.base.Strings;
/**
* <p>
* Statement rewriter which replaces named parameter tokens of the form :tokenName
* with <code>?</code> placeholders.
* </p>
* This is the default statement rewriter.
*
* This is a copy of the {@link ColonPrefixNamedParamStatementRewriter} with SQL expansion for
* multi-value arguments.
*
* TODO: find some time and send a pull request to include these changes. Here we are accessing
* some package-private classes, which isn't good.
*/
public class ExpandedStmtRewriter implements StatementRewriter
{
    /**
     * Statement-context attribute read back after {@link Argument#apply}. When it holds a
     * non-null {@link Integer}, the binder advances the JDBC index by that amount instead of 1.
     * NOTE(review): presumably set by {@code IterableArgument} to the number of values it
     * bound -- confirm against that class.
     */
    private static final String POSITION_ATTRIBUTE = "position";

    /**
     * Munge up the SQL as desired. Responsible for figuring out how to bind any
     * arguments in to the resultant prepared statement.
     *
     * @param sql The SQL to rewrite
     * @param params contains the arguments which have been bound to this statement.
     * @param ctx The statement context for the statement being executed
     * @return something which can provide the actual SQL to prepare a statement from
     *         and which can bind the correct arguments to that prepared statement
     * @throws UnableToCreateStatementException if parsing the SQL raises an
     *         {@link IllegalArgumentException}
     */
    @Override
    public RewrittenStatement rewrite(final String sql, final Binding params,
                                      final StatementContext ctx)
    {
        final ParsedStatement stmt = new ParsedStatement();
        try {
            final String parsedSql = parseString(sql, stmt, params);
            return new MyRewrittenStatement(parsedSql, stmt, ctx);
        }
        catch (IllegalArgumentException e) {
            throw new UnableToCreateStatementException(
                "Exception parsing for named parameter replacement", e, ctx);
        }
    }

    /**
     * Tokenizes {@code sql} and rebuilds it with every named or positional parameter replaced
     * by {@code ?}. A parameter whose bound argument is an {@code IterableArgument} is expanded
     * to one comma-separated {@code ?} per element; an empty iterable contributes no text.
     * Each parameter encountered is also recorded in {@code stmt}, in order of appearance.
     *
     * @param sql the SQL to parse
     * @param stmt receives the parameters found in {@code sql}
     * @param params the bound arguments, consulted only to decide whether to expand a parameter
     * @return the rewritten SQL
     * @throws IllegalArgumentException declared for the caller's benefit; not thrown directly here
     */
    String parseString(final String sql, final ParsedStatement stmt, final Binding params)
        throws IllegalArgumentException
    {
        final StringBuilder b = new StringBuilder();
        final ColonStatementLexer lexer = new ColonStatementLexer(new ANTLRStringStream(sql));
        // Index of the next positional *argument*. It moves by one per '?' token so that it
        // matches the lookup performed in MyRewrittenStatement#bind, which also advances its
        // argument index by one regardless of how many values the argument expands to.
        int pos = 0;
        for (Token t = lexer.nextToken();
             t.getType() != ColonStatementLexer.EOF;
             t = lexer.nextToken()) {
            switch (t.getType()) {
                case LITERAL:
                case QUOTED_TEXT:
                case DOUBLE_QUOTED_TEXT:
                    // Plain SQL and quoted text are copied through untouched.
                    b.append(t.getText());
                    break;
                case NAMED_PARAM: {
                    // Strip the leading ':' to get the parameter name.
                    final String pname = t.getText().substring(1);
                    stmt.addNamedParamAt(pname);
                    appendPlaceholders(b, params.forName(pname));
                    break;
                }
                case POSITIONAL_PARAM: {
                    appendPlaceholders(b, params.forPosition(pos));
                    pos += 1;
                    stmt.addPositionalParamAt();
                    break;
                }
                case ESCAPED_TEXT:
                    // Drop the first character of the token (presumably the escape character).
                    b.append(t.getText().substring(1));
                    break;
                default:
                    // Any other token type is dropped from the output.
                    break;
            }
        }
        return b.toString();
    }

    /**
     * Appends the placeholder text for one parameter: {@code "?, ?, ..., ?"} with one
     * {@code ?} per element when {@code arg} is an {@code IterableArgument}, a single
     * {@code ?} otherwise (including when {@code arg} is {@code null}).
     *
     * An empty iterable appends nothing. The buffer is never truncated, so text written
     * before the parameter cannot be damaged.
     */
    private static void appendPlaceholders(final StringBuilder b, final Argument arg)
    {
        if (arg instanceof IterableArgument) {
            final int size = ((IterableArgument) arg).size();
            if (size > 0) {
                // size - 1 separated placeholders followed by a final bare one.
                b.append(Strings.repeat("?, ", size - 1)).append('?');
            }
        }
        else {
            b.append('?');
        }
    }

    /**
     * Pairs the rewritten SQL with the parameter list discovered while parsing it, and binds
     * arguments to the prepared statement created from that SQL.
     */
    private static class MyRewrittenStatement implements RewrittenStatement
    {
        private final String sql;
        private final ParsedStatement stmt;
        private final StatementContext context;

        public MyRewrittenStatement(final String sql, final ParsedStatement stmt,
                                    final StatementContext ctx)
        {
            this.context = ctx;
            this.sql = sql;
            this.stmt = stmt;
        }

        /**
         * Binds {@code params} to {@code statement}. Statements without any named parameter
         * are bound purely by position; otherwise each named parameter is bound in order of
         * appearance.
         *
         * @param params the arguments to bind
         * @param statement the prepared statement created from {@link #getSql()}
         * @throws SQLException declared by the interface; binding failures are rethrown as
         *         the statement exceptions documented below
         * @throws UnableToExecuteStatementException if a positional argument fails to bind,
         *         or if no argument can be found for a named parameter
         * @throws UnableToCreateStatementException if a named argument fails to bind
         */
        @Override
        public void bind(final Binding params, final PreparedStatement statement) throws SQLException
        {
            if (stmt.positionalOnly) {
                bindPositional(params, statement);
            }
            else {
                bindNamed(params, statement);
            }
        }

        /**
         * Binds positional arguments 0, 1, 2, ... until the binding has no argument for the
         * next position.
         */
        private void bindPositional(final Binding params, final PreparedStatement statement)
        {
            // Zero-based JDBC placeholder offset; may advance by more than one per argument.
            int i = 0;
            for (int p = 0; ; p++) {
                final Argument a = params.forPosition(p);
                if (a == null) {
                    return;
                }
                try {
                    i += applyAndCountSlots(a, i, statement);
                }
                catch (SQLException e) {
                    throw new UnableToExecuteStatementException(
                        String.format(
                            "Exception while binding positional param at (0 based) position %d",
                            i), e, context);
                }
            }
        }

        /**
         * Binds every named parameter recorded by the parser. Positional markers ("*") in a
         * statement that also has named parameters are skipped, and a named parameter with no
         * matching argument falls back to a positional lookup.
         */
        private void bindNamed(final Binding params, final PreparedStatement statement)
        {
            // Zero-based JDBC placeholder offset; may advance by more than one per argument.
            int i = 0;
            for (String named_param : stmt.params) {
                if ("*".equals(named_param)) {
                    continue;
                }
                Argument a = params.forName(named_param);
                if (a == null) {
                    // NOTE(review): the fallback uses the JDBC offset, not the parameter's
                    // ordinal; the two differ once an earlier argument has been expanded.
                    a = params.forPosition(i);
                }
                if (a == null) {
                    String msg = String.format("Unable to execute, no named parameter matches " +
                                               "\"%s\" and no positional param for place %d (which is %d in " +
                                               "the JDBC 'start at 1' scheme) has been set.",
                                               named_param, i, i + 1);
                    throw new UnableToExecuteStatementException(msg, context);
                }
                try {
                    i += applyAndCountSlots(a, i, statement);
                }
                catch (SQLException e) {
                    throw new UnableToCreateStatementException(String.format(
                        "Exception while binding '%s'",
                        named_param), e, context);
                }
            }
        }

        /**
         * Applies {@code a} at the one-based JDBC index {@code offset + 1} and returns how far
         * the offset must advance: the Integer the argument left in the "position" context
         * attribute, or 1 when it left none. The attribute is cleared before the call so a
         * stale value is never read, and cleared again afterwards so none is left behind.
         */
        private int applyAndCountSlots(final Argument a, final int offset,
                                       final PreparedStatement statement) throws SQLException
        {
            context.setAttribute(POSITION_ATTRIBUTE, null);
            a.apply(offset + 1, statement, context);
            final Integer consumed = (Integer) context.getAttribute(POSITION_ATTRIBUTE);
            context.setAttribute(POSITION_ATTRIBUTE, null);
            return Optional.ofNullable(consumed).orElse(1);
        }

        @Override
        public String getSql()
        {
            return sql;
        }
    }

    /**
     * Records, in order of appearance, the parameters found while parsing a statement.
     * Named parameters are stored by name; positional parameters are stored as "*".
     */
    static class ParsedStatement
    {
        // Stays true until the first named parameter is recorded.
        private boolean positionalOnly = true;
        private final List<String> params = new ArrayList<>();

        /** Records a named parameter and marks the statement as not positional-only. */
        public void addNamedParamAt(final String name)
        {
            positionalOnly = false;
            params.add(name);
        }

        /** Records a positional parameter using the "*" marker. */
        public void addPositionalParamAt()
        {
            params.add("*");
        }
    }
}