package com.w11k.lsql.statement;
import com.google.common.base.CharMatcher;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.w11k.lsql.LSql;
import com.w11k.lsql.LiteralQueryParameter;
import com.w11k.lsql.QueryParameter;
import com.w11k.lsql.converter.Converter;
import com.w11k.lsql.exceptions.QueryException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class SqlStatementToPreparedStatement {

    // Query parameters are embedded in the SQL text as
    //
    //     ... column = /*name=*/ 123 /**/
    //
    // i.e. a start marker "/*name=*/", an inline value and the end marker "/**/". The span from the
    // first character of the start marker to the last character of the end marker is the parameter.
    // Parameters for which no value is supplied are left untouched in the executed SQL. An empty
    // name ("/*=*/") means the name is derived from the SQL token in front of the start marker
    // (see extractParameterName).
    private static final Pattern QUERY_ARG_START = Pattern.compile(
            "/\\*\\s*(\\S*)\\s*=\\s*\\*/");

    private static final String QUERY_ARG_END = "/**/";

    // Separators used when deriving an implicit parameter name from the preceding SQL. All ASCII
    // whitespace is included so that multi-line SQL does not yield names such as "WHERE\nage".
    private static final CharMatcher NAME_SEPARATORS = CharMatcher.anyOf("!=<> \t\r\n\f");

    /**
     * One occurrence of a named parameter in the SQL text. A name used several times in a statement
     * has one {@code Parameter} per occurrence.
     */
    public static final class Parameter {

        // Replacement text used when a value is bound: a single '?' padded with spaces to exactly the
        // length of the parameter span, so the indexes of all other parameters remain valid.
        String placeholder;

        String name;

        // Index of the first character of the start marker.
        int startIndex;

        // Index just past the last character of the end marker (exclusive).
        int endIndex;

        /** Returns the parameter name (explicit, or derived from the preceding SQL token). */
        public String getName() {
            return name;
        }

        /** Returns the index of the first character of the start marker. */
        public int getStartIndex() {
            return startIndex;
        }

        /** Returns the index just past the end marker (exclusive). */
        public int getEndIndex() {
            return endIndex;
        }

        @Override
        public String toString() {
            return "Parameter{" +
                    "placeholder='" + placeholder + '\'' +
                    ", name='" + name + '\'' +
                    ", startIndex=" + startIndex +
                    ", endIndex=" + endIndex +
                    '}';
        }
    }

    // A parameter occurrence paired with the value supplied for it in one execution.
    // Static: it does not need a reference to the enclosing instance.
    private static final class ParameterInPreparedStatement {
        Parameter parameter;
        Object value;

        @Override
        public String toString() {
            return "ParameterInPreparedStatement{" +
                    "parameter=" + parameter +
                    ", value=" + value +
                    '}';
        }
    }

    private final Logger logger = LoggerFactory.getLogger(getClass());

    private final LSql lSql;

    private final String statementName;

    private final String sqlString;

    // Parameter occurrences by name; the lists are unmodifiable.
    private final Map<String, List<Parameter>> parameters;

    /**
     * Parses all parameter markers in {@code sqlString}.
     *
     * @param lSql          the LSql instance whose dialect creates statements and converts values
     * @param statementName name used in log and error messages
     * @param sqlString     the SQL text containing parameter markers
     * @throws IllegalArgumentException if a start marker has no end marker, or if the name of a
     *                                  nameless parameter cannot be derived
     */
    public SqlStatementToPreparedStatement(LSql lSql, String statementName, String sqlString) {
        this.lSql = lSql;
        this.statementName = statementName;
        this.sqlString = sqlString;
        this.parameters = parseParameters();
    }

    public LSql getlSql() {
        return this.lSql;
    }

    public String getSqlString() {
        return this.sqlString;
    }

    /**
     * Returns all parameter occurrences by name. Each list is in the order of the occurrences in the
     * SQL text and is unmodifiable.
     */
    public ImmutableMap<String, List<Parameter>> getParameters() {
        return ImmutableMap.copyOf(this.parameters);
    }

    // Finds every start/end marker pair and groups the occurrences by parameter name.
    private Map<String, List<Parameter>> parseParameters() {
        Map<String, List<Parameter>> found = Maps.newHashMap();
        Matcher matcher = QUERY_ARG_START.matcher(this.sqlString);
        while (matcher.find()) {
            Parameter p = new Parameter();

            String explicitName = matcher.group(1);
            p.name = explicitName.isEmpty()
                    ? extractParameterName(this.sqlString, matcher.start())
                    : explicitName;

            p.startIndex = matcher.start();

            // Search after the start marker so that the end marker can never overlap it.
            int endMarker = this.sqlString.indexOf(QUERY_ARG_END, matcher.end());
            if (endMarker == -1) {
                throw new IllegalArgumentException("Unable to find end marker for parameter '" + p.name + "'");
            }
            p.endIndex = endMarker + QUERY_ARG_END.length();

            p.placeholder = "?" + Strings.repeat(" ", p.endIndex - p.startIndex - 1);

            List<Parameter> parametersForName = found.get(p.name);
            if (parametersForName == null) {
                parametersForName = Lists.newArrayList();
                found.put(p.name, parametersForName);
            }
            parametersForName.add(p);
        }

        // Freeze the lists so that getParameters() cannot be used to corrupt internal state.
        for (Map.Entry<String, List<Parameter>> entry : found.entrySet()) {
            entry.setValue(Collections.unmodifiableList(entry.getValue()));
        }
        return found;
    }

    /**
     * Derives the name of a nameless parameter from the SQL in front of its start marker: the last
     * token, skipping a trailing {@code IS} keyword (so {@code age IS} yields {@code age}).
     *
     * @throws IllegalArgumentException if there is no token to derive the name from
     */
    private String extractParameterName(String sqlString, int start) {
        String left = sqlString.substring(0, start).trim();
        List<String> tokens = Lists.newArrayList(Splitter.on(NAME_SEPARATORS).omitEmptyStrings().split(left));

        int index = tokens.size() - 1;
        // equalsIgnoreCase is locale-independent, unlike toUpperCase() (e.g. Turkish dotted I).
        if (index >= 0 && tokens.get(index).equalsIgnoreCase("IS")) {
            index--;
        }
        if (index < 0) {
            throw new IllegalArgumentException("Unable to derive a name for the parameter at index "
                    + start + " in statement '" + this.statementName + "'");
        }
        return tokens.get(index);
    }

    // Trace: every parameter with its value, sorted by name. Debug: only the parameter names.
    private void log(Map<String, Object> queryParameters) {
        if (this.logger.isTraceEnabled()) {
            List<String> keys = Lists.newArrayList(queryParameters.keySet());
            Collections.sort(keys, String.CASE_INSENSITIVE_ORDER);
            StringBuilder msg = new StringBuilder("Executing statement '")
                    .append(this.statementName)
                    .append("' with parameters:\n");
            for (String key : keys) {
                msg.append(String.format("%15s = %s\n", key, queryParameters.get(key)));
            }
            this.logger.trace(msg.toString());
        } else if (this.logger.isDebugEnabled()) {
            this.logger.debug("Executing statement '{}' with parameters {}", this.statementName, queryParameters.keySet());
        }
    }

    // Replaces the span of every LiteralQueryParameter with its SQL string. This is done in a single
    // left-to-right pass because the replacement changes lengths and thereby invalidates the indexes
    // of later parameters. Requires the list to be sorted by start index.
    private String processRawConversions(String sql, List<ParameterInPreparedStatement> parameterInPreparedStatements) {
        StringBuilder result = new StringBuilder(sql.length());
        int lastIndex = 0;
        for (ParameterInPreparedStatement pips : parameterInPreparedStatements) {
            if (pips.value instanceof LiteralQueryParameter) {
                LiteralQueryParameter literalQueryParameter = (LiteralQueryParameter) pips.value;
                result.append(sql, lastIndex, pips.parameter.startIndex);
                result.append(literalQueryParameter.getSqlString());
                lastIndex = pips.parameter.endIndex;
            }
        }
        result.append(sql, lastIndex, sql.length());
        return result.toString();
    }

    /**
     * Creates a {@link PreparedStatement} for this statement with the given values bound.
     *
     * <p>Every occurrence of a supplied parameter is replaced by a JDBC {@code ?}, or, for a
     * {@link LiteralQueryParameter}, by its SQL string. {@link QueryParameter} values bind
     * themselves, {@code null} is bound with {@code setNull(..., Types.OTHER)}, and any other value
     * is bound by the converter registered for its Java type.
     *
     * @param queryParameters values by parameter name; every key must occur in the statement
     * @return the open prepared statement
     * @throws QueryException           if a key does not name a parameter of this statement
     * @throws IllegalArgumentException if no converter is registered for a value's type
     * @throws SQLException             if creating the statement or binding a value fails. If
     *                                  binding fails, the created statement is closed first.
     */
    PreparedStatement createPreparedStatement(Map<String, Object> queryParameters) throws SQLException {
        log(queryParameters);

        StringBuilder sql = new StringBuilder(this.sqlString);
        List<ParameterInPreparedStatement> bindings = Lists.newArrayList();
        for (Map.Entry<String, Object> entry : queryParameters.entrySet()) {
            List<Parameter> parametersByName = this.parameters.get(entry.getKey());
            if (parametersByName == null) {
                throw new QueryException("Unused query parameter: " + entry.getKey());
            }
            for (Parameter p : parametersByName) {
                // Same-length replacement: indexes of all other parameters stay valid.
                sql.replace(p.startIndex, p.endIndex, p.placeholder);
                ParameterInPreparedStatement pips = new ParameterInPreparedStatement();
                pips.parameter = p;
                pips.value = entry.getValue();
                bindings.add(pips);
            }
        }

        // JDBC binds by position, so order the values as they appear in the SQL text.
        Collections.sort(bindings, new Comparator<ParameterInPreparedStatement>() {
            @Override
            public int compare(ParameterInPreparedStatement o1, ParameterInPreparedStatement o2) {
                // Both indexes are non-negative, so the subtraction cannot overflow.
                return o1.parameter.startIndex - o2.parameter.startIndex;
            }
        });

        String finalSql = processRawConversions(sql.toString(), bindings);
        PreparedStatement ps = this.lSql.getDialect().getStatementCreator().createPreparedStatement(this.lSql, finalSql, false);
        boolean bound = false;
        try {
            bindParameters(ps, bindings);
            bound = true;
            return ps;
        } finally {
            if (!bound) {
                // The caller never receives the statement, so it must be closed here.
                closeQuietly(ps);
            }
        }
    }

    // Binds all values in SQL order. A LiteralQueryParameter contributes as many '?' as it reports
    // (possibly zero); every other value contributes exactly one.
    private void bindParameters(PreparedStatement ps, List<ParameterInPreparedStatement> bindings) throws SQLException {
        int jdbcIndex = 1;
        for (ParameterInPreparedStatement pips : bindings) {
            Object value = pips.value;
            if (value instanceof QueryParameter) {
                ((QueryParameter) value).set(ps, jdbcIndex);
                jdbcIndex++;
            } else if (value instanceof LiteralQueryParameter) {
                LiteralQueryParameter literal = (LiteralQueryParameter) value;
                for (int localIndex = 0; localIndex < literal.getNumberOfQueryParameters(); localIndex++) {
                    literal.set(ps, jdbcIndex + localIndex, localIndex);
                }
                jdbcIndex += literal.getNumberOfQueryParameters();
            } else if (value == null) {
                ps.setNull(jdbcIndex, Types.OTHER);
                jdbcIndex++;
            } else {
                Converter converter = this.lSql.getDialect().getConverterRegistry()
                        .getConverterForJavaType(value.getClass());
                if (converter == null) {
                    throw new IllegalArgumentException(this.statementName + ": no registered converter for parameter " + pips);
                }
                converter.setValueInStatement(this.lSql, ps, jdbcIndex, value);
                jdbcIndex++;
            }
        }
    }

    // Closes a statement that is being abandoned because of an earlier failure. A failure to close
    // is only logged so that it does not mask the original exception.
    private void closeQuietly(PreparedStatement ps) {
        if (ps == null) {
            return;
        }
        try {
            ps.close();
        } catch (SQLException e) {
            this.logger.warn("Failed to close prepared statement for '" + this.statementName + "'", e);
        }
    }
}