/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.test.util.jdbc;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.jdbc.ConnectionProviderDelegate;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
/**
* This {@link ConnectionProvider} extends any other ConnectionProvider that would be used by default taken the current configuration properties, and it
* intercept the underlying {@link PreparedStatement} method calls.
*
* @author Vlad Mihalcea
*/
public class PreparedStatementSpyConnectionProvider
extends ConnectionProviderDelegate {
private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();
private final List<String> executeStatements = new ArrayList<>();
private final List<String> executeUpdateStatements = new ArrayList<>();
private final List<Connection> acquiredConnections = new ArrayList<>( );
private final List<Connection> releasedConnections = new ArrayList<>( );
public PreparedStatementSpyConnectionProvider() {
}
public PreparedStatementSpyConnectionProvider(ConnectionProvider connectionProvider) {
super( connectionProvider );
}
protected Connection actualConnection() throws SQLException {
return super.getConnection();
}
@Override
public Connection getConnection() throws SQLException {
Connection connection = spy( actualConnection() );
acquiredConnections.add( connection );
return connection;
}
@Override
public void closeConnection(Connection conn) throws SQLException {
acquiredConnections.remove( conn );
releasedConnections.add( conn );
super.closeConnection( conn );
}
@Override
public void stop() {
clear();
super.stop();
}
private Connection spy(Connection connection) {
if ( MockUtil.isMock( connection ) ) {
return connection;
}
Connection connectionSpy = Mockito.spy( connection );
try {
Mockito.doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement );
String sql = (String) invocation.getArguments()[0];
preparedStatementMap.put( statementSpy, sql );
return statementSpy;
} ).when( connectionSpy ).prepareStatement( ArgumentMatchers.anyString() );
Mockito.doAnswer( invocation -> {
Statement statement = (Statement) invocation.callRealMethod();
Statement statementSpy = Mockito.spy( statement );
Mockito.doAnswer( statementInvocation -> {
String sql = (String) statementInvocation.getArguments()[0];
executeStatements.add( sql );
return statementInvocation.callRealMethod();
} ).when( statementSpy ).execute( ArgumentMatchers.anyString() );
Mockito.doAnswer( statementInvocation -> {
String sql = (String) statementInvocation.getArguments()[0];
executeUpdateStatements.add( sql );
return statementInvocation.callRealMethod();
} ).when( statementSpy ).executeUpdate( ArgumentMatchers.anyString() );
return statementSpy;
} ).when( connectionSpy ).createStatement();
}
catch ( SQLException e ) {
throw new IllegalArgumentException( e );
}
return connectionSpy;
}
/**
* Clears the recorded PreparedStatements and reset the associated Mocks.
*/
public void clear() {
acquiredConnections.clear();
releasedConnections.clear();
preparedStatementMap.keySet().forEach( Mockito::reset );
preparedStatementMap.clear();
}
/**
* Get one and only one PreparedStatement associated to the given SQL statement.
*
* @param sql SQL statement.
*
* @return matching PreparedStatement.
*
* @throws IllegalArgumentException If there is no matching PreparedStatement or multiple instances, an exception is being thrown.
*/
public PreparedStatement getPreparedStatement(String sql) {
List<PreparedStatement> preparedStatements = getPreparedStatements( sql );
if ( preparedStatements.isEmpty() ) {
throw new IllegalArgumentException(
"There is no PreparedStatement for this SQL statement " + sql );
}
else if ( preparedStatements.size() > 1 ) {
throw new IllegalArgumentException( "There are " + preparedStatements
.size() + " PreparedStatements for this SQL statement " + sql );
}
return preparedStatements.get( 0 );
}
/**
* Get the PreparedStatements that are associated to the following SQL statement.
*
* @param sql SQL statement.
*
* @return list of recorded PreparedStatements matching the SQL statement.
*/
public List<PreparedStatement> getPreparedStatements(String sql) {
return preparedStatementMap.entrySet()
.stream()
.filter( entry -> entry.getValue().equals( sql ) )
.map( Map.Entry::getKey )
.collect( Collectors.toList() );
}
/**
* Get the PreparedStatements that were executed since the last clear operation.
*
* @return list of recorded PreparedStatements.
*/
public List<PreparedStatement> getPreparedStatements() {
return new ArrayList<>( preparedStatementMap.keySet() );
}
/**
* Get the SQL statements that were executed since the last clear operation.
* @return list of recorded update statements.
*/
public List<String> getExecuteStatements() {
return executeStatements;
}
/**
* Get the SQL update statements that were executed since the last clear operation.
* @return list of recorded update statements.
*/
public List<String> getExecuteUpdateStatements() {
return executeUpdateStatements;
}
/**
* Get a list of current acquired Connections.
* @return list of current acquired Connections
*/
public List<Connection> getAcquiredConnections() {
return acquiredConnections;
}
/**
* Get a list of current released Connections.
* @return list of current released Connections
*/
public List<Connection> getReleasedConnections() {
return releasedConnections;
}
}