/* * Copyright 2015, The Querydsl Team (http://www.querydsl.com/team) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.querydsl.sql; import java.lang.reflect.InvocationTargetException; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import javax.annotation.Nullable; import javax.inject.Provider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.MDC; import com.google.common.collect.ImmutableList; import com.mysema.commons.lang.CloseableIterator; import com.querydsl.core.*; import com.querydsl.core.support.QueryMixin; import com.querydsl.core.types.*; import com.querydsl.core.types.dsl.Expressions; import com.querydsl.core.types.dsl.SimpleExpression; import com.querydsl.core.types.dsl.Wildcard; import com.querydsl.core.util.ResultSetAdapter; /** * {@code AbstractSQLQuery} is the base type for SQL query implementations * * @param <T> result type * @param <Q> concrete subtype * * @author tiwe */ public abstract class AbstractSQLQuery<T, Q extends AbstractSQLQuery<T, Q>> extends ProjectableSQLQuery<T, Q> { protected static final String PARENT_CONTEXT = AbstractSQLQuery.class.getName() + "#PARENT_CONTEXT"; private static final Logger logger = LoggerFactory.getLogger(AbstractSQLQuery.class); private static final QueryFlag rowCountFlag = new QueryFlag(QueryFlag.Position.AFTER_PROJECTION, ", count(*) over() "); @Nullable private Provider<Connection> connProvider; @Nullable private Connection conn; protected SQLListeners listeners; protected boolean useLiterals; private boolean getLastCell; private Object lastCell; private SQLListenerContext parentContext; private StatementOptions statementOptions = StatementOptions.DEFAULT; public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration) { this(conn, configuration, new DefaultQueryMetadata()); } public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration, QueryMetadata metadata) { super(new QueryMixin<Q>(metadata, false), configuration); this.conn = conn; this.listeners = new SQLListeners(configuration.getListeners()); this.useLiterals = configuration.getUseLiterals(); } public AbstractSQLQuery(Provider<Connection> connProvider, Configuration configuration) { this(connProvider, configuration, new DefaultQueryMetadata()); } public AbstractSQLQuery(Provider<Connection> connProvider, Configuration configuration, QueryMetadata metadata) { super(new QueryMixin<Q>(metadata, false), configuration); this.connProvider = connProvider; this.listeners = new SQLListeners(configuration.getListeners()); this.useLiterals = configuration.getUseLiterals(); } /** * Create an alias for the expression * * @param alias alias * @return this as alias */ public SimpleExpression<T> as(String alias) { return Expressions.as(this, alias); } /** * Create an alias for the expression * * @param alias alias * @return this as alias */ @SuppressWarnings("unchecked") public SimpleExpression<T> as(Path<?> alias) { return Expressions.as(this, (Path) alias); } /** * Add a listener * * @param listener listener to add */ public void addListener(SQLListener listener) { listeners.add(listener); } @Override public long fetchCount() { try { return unsafeCount(); } catch (SQLException e) { String error = "Caught " + e.getClass().getName(); logger.error(error, e); throw configuration.translate(e); } } /** * If you use forUpdate() with a backend that uses page or row locks, rows examined by the * query are write-locked until the end of the current transaction. * * Not supported for SQLite and CUBRID * * @return the current object */ public Q forUpdate() { QueryFlag forUpdateFlag = configuration.getTemplates().getForUpdateFlag(); return addFlag(forUpdateFlag); } /** * FOR SHARE causes the rows retrieved by the SELECT statement to be locked as though for update. * * Supported by MySQL, PostgreSQL, SQLServer. * * @return the current object * * @throws QueryException * if the FOR SHARE is not supported. */ public Q forShare() { return forShare(false); } /** * FOR SHARE causes the rows retrieved by the SELECT statement to be locked as though for update. * * Supported by MySQL, PostgreSQL, SQLServer. * * @param fallbackToForUpdate * if the FOR SHARE is not supported and this parameter is <code>true</code>, the * {@link #forUpdate()} functionality will be used. * * @return the current object * * @throws QueryException * if the FOR SHARE is not supported and <i>fallbackToForUpdate</i> is set to * <code>false</code>. */ public Q forShare(boolean fallbackToForUpdate) { SQLTemplates sqlTemplates = configuration.getTemplates(); if (sqlTemplates.isForShareSupported()) { QueryFlag forShareFlag = sqlTemplates.getForShareFlag(); return addFlag(forShareFlag); } if (fallbackToForUpdate) { return forUpdate(); } throw new QueryException("Using forShare() is not supported"); } @Override protected SQLSerializer createSerializer() { SQLSerializer serializer = new SQLSerializer(configuration); serializer.setUseLiterals(useLiterals); return serializer; } @Nullable private <U> U get(ResultSet rs, Expression<?> expr, int i, Class<U> type) throws SQLException { return configuration.get(rs, expr instanceof Path ? (Path<?>) expr : null, i, type); } private void set(PreparedStatement stmt, Path<?> path, int i, Object value) throws SQLException { configuration.set(stmt, path, i, value); } /** * Called to create and start a new SQL Listener context * * @param connection the database connection * @param metadata the meta data for that context * @return the newly started context */ protected SQLListenerContextImpl startContext(Connection connection, QueryMetadata metadata) { SQLListenerContextImpl context = new SQLListenerContextImpl(metadata, connection); if (parentContext != null) { context.setData(PARENT_CONTEXT, parentContext); } listeners.start(context); return context; } /** * Called to make the call back to listeners when an exception happens * * @param context the current context in play * @param e the exception */ protected void onException(SQLListenerContextImpl context, Exception e) { context.setException(e); listeners.exception(context); } /** * Called to end a SQL listener context * * @param context the listener context to end */ protected void endContext(SQLListenerContext context) { listeners.end(context); } /** * Get the results as a JDBC ResultSet * * @param exprs the expression arguments to retrieve * @return results as ResultSet * @deprecated Use @{code select(..)} to define the projection and {@code getResults()} to obtain * the result set */ @Deprecated public ResultSet getResults(Expression<?>... exprs) { if (exprs.length > 0) { queryMixin.setProjection(exprs); } return getResults(); } /** * Get the results as a JDBC ResultSet * * @return results as ResultSet */ public ResultSet getResults() { final SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata()); String queryString = null; List<Object> constants = ImmutableList.of(); try { listeners.preRender(context); SQLSerializer serializer = serialize(false); queryString = serializer.toString(); logQuery(queryString, serializer.getConstants()); context.addSQL(queryString); listeners.rendered(context); listeners.notifyQuery(queryMixin.getMetadata()); constants = serializer.getConstants(); listeners.prePrepare(context); final PreparedStatement stmt = getPreparedStatement(queryString); setParameters(stmt, constants, serializer.getConstantPaths(), getMetadata().getParams()); context.addPreparedStatement(stmt); listeners.prepared(context); listeners.preExecute(context); final ResultSet rs = stmt.executeQuery(); listeners.executed(context); return new ResultSetAdapter(rs) { @Override public void close() throws SQLException { try { super.close(); } finally { stmt.close(); reset(); endContext(context); } } }; } catch (SQLException e) { onException(context, e); reset(); endContext(context); throw configuration.translate(queryString, constants, e); } } private PreparedStatement getPreparedStatement(String queryString) throws SQLException { PreparedStatement statement = connection().prepareStatement(queryString); if (statementOptions.getFetchSize() != null) { statement.setFetchSize(statementOptions.getFetchSize()); } if (statementOptions.getMaxFieldSize() != null) { statement.setMaxFieldSize(statementOptions.getMaxFieldSize()); } if (statementOptions.getQueryTimeout() != null) { statement.setQueryTimeout(statementOptions.getQueryTimeout()); } if (statementOptions.getMaxRows() != null) { statement.setMaxRows(statementOptions.getMaxRows()); } return statement; } protected Configuration getConfiguration() { return configuration; } @SuppressWarnings("unchecked") @Override public CloseableIterator<T> iterate() { Expression<T> expr = (Expression<T>) queryMixin.getMetadata().getProjection(); return iterateSingle(queryMixin.getMetadata(), expr); } @SuppressWarnings("unchecked") private CloseableIterator<T> iterateSingle(QueryMetadata metadata, @Nullable final Expression<T> expr) { SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata()); String queryString = null; List<Object> constants = ImmutableList.of(); try { listeners.preRender(context); SQLSerializer serializer = serialize(false); queryString = serializer.toString(); logQuery(queryString, serializer.getConstants()); context.addSQL(queryString); listeners.rendered(context); listeners.notifyQuery(queryMixin.getMetadata()); constants = serializer.getConstants(); listeners.prePrepare(context); final PreparedStatement stmt = getPreparedStatement(queryString); setParameters(stmt, constants, serializer.getConstantPaths(), metadata.getParams()); context.addPreparedStatement(stmt); listeners.prepared(context); listeners.preExecute(context); final ResultSet rs = stmt.executeQuery(); listeners.executed(context); if (expr == null) { return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) { @Override public T produceNext(ResultSet rs) throws Exception { return (T) rs.getObject(1); } }; } else if (expr instanceof FactoryExpression) { return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) { @Override public T produceNext(ResultSet rs) throws Exception { return newInstance((FactoryExpression<T>) expr, rs, 0); } }; } else if (expr.equals(Wildcard.all)) { return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) { @Override public T produceNext(ResultSet rs) throws Exception { Object[] rv = new Object[rs.getMetaData().getColumnCount()]; for (int i = 0; i < rv.length; i++) { rv[i] = rs.getObject(i + 1); } return (T) rv; } }; } else { return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) { @Override public T produceNext(ResultSet rs) throws Exception { return get(rs, expr, 1, expr.getType()); } }; } } catch (SQLException e) { onException(context, e); endContext(context); throw configuration.translate(queryString, constants, e); } catch (RuntimeException e) { logger.error("Caught " + e.getClass().getName() + " for " + queryString); throw e; } finally { reset(); } } @SuppressWarnings("unchecked") @Override public List<T> fetch() { Expression<T> expr = (Expression<T>) queryMixin.getMetadata().getProjection(); SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata()); String queryString = null; List<Object> constants = ImmutableList.of(); try { listeners.preRender(context); SQLSerializer serializer = serialize(false); queryString = serializer.toString(); logQuery(queryString, serializer.getConstants()); context.addSQL(queryString); listeners.rendered(context); listeners.notifyQuery(queryMixin.getMetadata()); constants = serializer.getConstants(); listeners.prePrepare(context); final PreparedStatement stmt = getPreparedStatement(queryString); try { setParameters(stmt, constants, serializer.getConstantPaths(), queryMixin.getMetadata().getParams()); context.addPreparedStatement(stmt); listeners.prepared(context); listeners.preExecute(context); final ResultSet rs = stmt.executeQuery(); listeners.executed(context); try { lastCell = null; final List<T> rv = new ArrayList<T>(); if (expr instanceof FactoryExpression) { FactoryExpression<T> fe = (FactoryExpression<T>) expr; while (rs.next()) { if (getLastCell) { lastCell = rs.getObject(fe.getArgs().size() + 1); getLastCell = false; } rv.add(newInstance(fe, rs, 0)); } } else if (expr.equals(Wildcard.all)) { while (rs.next()) { Object[] row = new Object[rs.getMetaData().getColumnCount()]; if (getLastCell) { lastCell = rs.getObject(row.length); getLastCell = false; } for (int i = 0; i < row.length; i++) { row[i] = rs.getObject(i + 1); } rv.add((T) row); } } else { while (rs.next()) { if (getLastCell) { lastCell = rs.getObject(2); getLastCell = false; } rv.add(get(rs, expr, 1, expr.getType())); } } return rv; } catch (IllegalAccessException e) { onException(context, e); throw new QueryException(e); } catch (InvocationTargetException e) { onException(context,e); throw new QueryException(e); } catch (InstantiationException e) { onException(context,e); throw new QueryException(e); } catch (SQLException e) { onException(context,e); throw configuration.translate(queryString, constants, e); } finally { rs.close(); } } finally { stmt.close(); } } catch (SQLException e) { onException(context, e); throw configuration.translate(queryString, constants, e); } finally { endContext(context); reset(); } } @SuppressWarnings("unchecked") @Override public QueryResults<T> fetchResults() { parentContext = startContext(connection(), queryMixin.getMetadata()); Expression<T> expr = (Expression<T>) queryMixin.getMetadata().getProjection(); QueryModifiers originalModifiers = queryMixin.getMetadata().getModifiers(); try { if (configuration.getTemplates().isCountViaAnalytics() && queryMixin.getMetadata().getGroupBy().isEmpty()) { List<T> results; try { queryMixin.addFlag(rowCountFlag); getLastCell = true; results = fetch(); } finally { queryMixin.removeFlag(rowCountFlag); } long total; if (!results.isEmpty()) { if (lastCell instanceof Number) { total = ((Number) lastCell).longValue(); } else { throw new IllegalStateException("Unsupported lastCell instance " + lastCell); } } else { total = fetchCount(); } return new QueryResults<T>(results, originalModifiers, total); } else { queryMixin.setProjection(expr); long total = fetchCount(); if (total > 0) { return new QueryResults<T>(fetch(), originalModifiers, total); } else { return QueryResults.emptyResults(); } } } finally { endContext(parentContext); reset(); getLastCell = false; parentContext = null; } } private <RT> RT newInstance(FactoryExpression<RT> c, ResultSet rs, int offset) throws InstantiationException, IllegalAccessException, InvocationTargetException, SQLException { Object[] args = new Object[c.getArgs().size()]; for (int i = 0; i < args.length; i++) { args[i] = get(rs, c.getArgs().get(i), offset + i + 1, c.getArgs().get(i).getType()); } return c.newInstance(args); } private void reset() { cleanupMDC(); } protected void setParameters(PreparedStatement stmt, List<?> objects, List<Path<?>> constantPaths, Map<ParamExpression<?>, ?> params) { if (objects.size() != constantPaths.size()) { throw new IllegalArgumentException("Expected " + objects.size() + " paths, but got " + constantPaths.size()); } for (int i = 0; i < objects.size(); i++) { Object o = objects.get(i); try { if (o instanceof ParamExpression) { if (!params.containsKey(o)) { throw new ParamNotSetException((ParamExpression<?>) o); } o = params.get(o); } set(stmt, constantPaths.get(i), i + 1, o); } catch (SQLException e) { throw configuration.translate(e); } } } private long unsafeCount() throws SQLException { SQLListenerContextImpl context = startContext(connection(), getMetadata()); String queryString = null; List<Object> constants = ImmutableList.of(); PreparedStatement stmt = null; ResultSet rs = null; try { listeners.preRender(context); SQLSerializer serializer = serialize(true); queryString = serializer.toString(); logQuery(queryString, serializer.getConstants()); context.addSQL(queryString); listeners.rendered(context); constants = serializer.getConstants(); listeners.prePrepare(context); stmt = getPreparedStatement(queryString); setParameters(stmt, constants, serializer.getConstantPaths(), getMetadata().getParams()); context.addPreparedStatement(stmt); listeners.prepared(context); listeners.preExecute(context); rs = stmt.executeQuery(); boolean hasResult = rs.next(); listeners.executed(context); if (hasResult) { return rs.getLong(1); } else { return 0; } } catch (SQLException e) { onException(context, e); throw configuration.translate(queryString, constants, e); } finally { try { if (rs != null) { rs.close(); } } finally { if (stmt != null) { stmt.close(); } } endContext(context); cleanupMDC(); } } protected void logQuery(String queryString, Collection<Object> parameters) { if (logger.isDebugEnabled()) { String normalizedQuery = queryString.replace('\n', ' '); MDC.put(MDC_QUERY, normalizedQuery); MDC.put(MDC_PARAMETERS, String.valueOf(parameters)); logger.debug(normalizedQuery); } } protected void cleanupMDC() { MDC.remove(MDC_QUERY); MDC.remove(MDC_PARAMETERS); } private Connection connection() { if (conn == null) { if (connProvider != null) { conn = connProvider.get(); } else { throw new IllegalStateException("No connection provided"); } } return conn; } /** * Set whether literals are used in SQL strings instead of parameter bindings (default: false) * * <p>Warning: When literals are used, prepared statement won't have any parameter bindings * and also batch statements will only be simulated, but not executed as actual batch statements.</p> * * @param useLiterals true for literals and false for bindings */ public void setUseLiterals(boolean useLiterals) { this.useLiterals = useLiterals; } @Override protected void clone(Q query) { super.clone(query); this.useLiterals = query.useLiterals; this.listeners = new SQLListeners(query.listeners); } @Override public Q clone() { return this.clone(this.conn); } public abstract Q clone(Connection connection); /** * Set the options to be applied to the JDBC statements of this query * * @param statementOptions options to be applied to statements */ public void setStatementOptions(StatementOptions statementOptions) { this.statementOptions = statementOptions; } }