/*
* Copyright 2015, The Querydsl Team (http://www.querydsl.com/team)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.querydsl.sql.dml;
import java.sql.*;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import javax.inject.Provider;
import org.slf4j.Logger;
import org.slf4j.MDC;
import com.google.common.collect.ImmutableList;
import com.querydsl.core.QueryMetadata;
import com.querydsl.core.dml.DMLClause;
import com.querydsl.core.support.QueryBase;
import com.querydsl.core.types.ParamExpression;
import com.querydsl.core.types.ParamNotSetException;
import com.querydsl.core.types.Path;
import com.querydsl.sql.*;
/**
* {@code AbstractSQLClause} is a superclass for SQL based DMLClause implementations
*
* @param <C> concrete subtype
*
* @author tiwe
*/
public abstract class AbstractSQLClause<C extends AbstractSQLClause<C>> implements DMLClause<C> {
protected final Configuration configuration;
protected final SQLListeners listeners;
protected boolean useLiterals;
protected SQLListenerContextImpl context;
@Nullable
private Provider<Connection> connProvider;
@Nullable
private Connection conn;
public AbstractSQLClause(Configuration configuration) {
this.configuration = configuration;
this.listeners = new SQLListeners(configuration.getListeners());
this.useLiterals = configuration.getUseLiterals();
}
public AbstractSQLClause(Configuration configuration, Provider<Connection> connProvider) {
this(configuration);
this.connProvider = connProvider;
}
public AbstractSQLClause(Configuration configuration, Connection conn) {
this(configuration);
this.conn = conn;
}
/**
* Add a listener
*
* @param listener listener to add
*/
public void addListener(SQLListener listener) {
listeners.add(listener);
}
/**
* Clear the internal state of the clause
*/
public abstract void clear();
/**
* Called to create and start a new SQL Listener context
*
* @param connection the database connection
* @param metadata the meta data for that context
* @param entity the entity for that context
* @return the newly started context
*/
protected SQLListenerContextImpl startContext(Connection connection, QueryMetadata metadata, RelationalPath<?> entity) {
SQLListenerContextImpl context = new SQLListenerContextImpl(metadata, connection, entity);
listeners.start(context);
return context;
}
/**
* Called to make the call back to listeners when an exception happens
*
* @param context the current context in play
* @param e the exception
*/
protected void onException(SQLListenerContextImpl context, Exception e) {
context.setException(e);
listeners.exception(context);
}
/**
* Called to end a SQL listener context
*
* @param context the listener context to end
*/
protected void endContext(SQLListenerContextImpl context) {
listeners.end(context);
this.context = null;
}
protected SQLBindings createBindings(QueryMetadata metadata, SQLSerializer serializer) {
String queryString = serializer.toString();
ImmutableList.Builder<Object> args = ImmutableList.builder();
Map<ParamExpression<?>, Object> params = metadata.getParams();
for (Object o : serializer.getConstants()) {
if (o instanceof ParamExpression) {
if (!params.containsKey(o)) {
throw new ParamNotSetException((ParamExpression<?>) o);
}
o = metadata.getParams().get(o);
}
args.add(o);
}
return new SQLBindings(queryString, args.build());
}
protected SQLSerializer createSerializer() {
SQLSerializer serializer = new SQLSerializer(configuration, true);
serializer.setUseLiterals(useLiterals);
return serializer;
}
/**
* Get the SQL string and bindings
*
* @return SQL and bindings
*/
public abstract List<SQLBindings> getSQL();
/**
* Set the parameters to the given PreparedStatement
*
* @param stmt preparedStatement to be populated
* @param objects list of constants
* @param constantPaths list of paths related to the constants
* @param params map of param to value for param resolving
*/
protected void setParameters(PreparedStatement stmt, List<?> objects,
List<Path<?>> constantPaths, Map<ParamExpression<?>, ?> params) {
if (objects.size() != constantPaths.size()) {
throw new IllegalArgumentException("Expected " + objects.size() + " paths, " +
"but got " + constantPaths.size());
}
for (int i = 0; i < objects.size(); i++) {
Object o = objects.get(i);
try {
if (o instanceof ParamExpression) {
if (!params.containsKey(o)) {
throw new ParamNotSetException((ParamExpression<?>) o);
}
o = params.get(o);
}
configuration.set(stmt, constantPaths.get(i), i + 1, o);
} catch (SQLException e) {
throw configuration.translate(e);
}
}
}
private long executeBatch(PreparedStatement stmt) throws SQLException {
if (configuration.getUseLiterals()) {
return stmt.executeUpdate();
} else if (configuration.getTemplates().isBatchCountViaGetUpdateCount()) {
stmt.executeBatch();
return stmt.getUpdateCount();
} else {
long rv = 0;
for (int i : stmt.executeBatch()) {
rv += i;
}
return rv;
}
}
protected long executeBatch(Collection<PreparedStatement> stmts) throws SQLException {
long rv = 0;
for (PreparedStatement stmt : stmts) {
rv += executeBatch(stmt);
}
return rv;
}
protected void close(Statement stmt) {
try {
stmt.close();
} catch (SQLException e) {
throw configuration.translate(e);
}
}
protected void close(Collection<? extends Statement> stmts) {
for (Statement stmt : stmts) {
close(stmt);
}
}
protected void close(ResultSet rs) {
try {
rs.close();
} catch (SQLException e) {
throw configuration.translate(e);
}
}
protected void logQuery(Logger logger, String queryString, Collection<Object> parameters) {
if (logger.isDebugEnabled()) {
String normalizedQuery = queryString.replace('\n', ' ');
MDC.put(QueryBase.MDC_QUERY, normalizedQuery);
MDC.put(QueryBase.MDC_PARAMETERS, String.valueOf(parameters));
logger.debug(normalizedQuery);
}
}
protected void cleanupMDC() {
MDC.remove(QueryBase.MDC_QUERY);
MDC.remove(QueryBase.MDC_PARAMETERS);
}
protected void reset() {
cleanupMDC();
}
protected Connection connection() {
if (conn == null) {
if (connProvider != null) {
conn = connProvider.get();
} else {
throw new IllegalStateException("No connection provided");
}
}
return conn;
}
public void setUseLiterals(boolean useLiterals) {
this.useLiterals = useLiterals;
}
public abstract int getBatchCount();
}