/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.InOrder;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;
import java.util.Map;
import org.neo4j.driver.internal.retry.FixedRetryLogic;
import org.neo4j.driver.internal.retry.RetryLogic;
import org.neo4j.driver.internal.spi.Collector;
import org.neo4j.driver.internal.spi.ConnectionProvider;
import org.neo4j.driver.internal.spi.PooledConnection;
import org.neo4j.driver.internal.util.Supplier;
import org.neo4j.driver.v1.AccessMode;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.Transaction;
import org.neo4j.driver.v1.TransactionWork;
import org.neo4j.driver.v1.Value;
import org.neo4j.driver.v1.exceptions.ClientException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import org.neo4j.driver.v1.exceptions.SessionExpiredException;
import static java.util.Collections.singletonMap;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.RETURNS_MOCKS;
import static org.mockito.Mockito.anyMapOf;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING;
import static org.neo4j.driver.internal.spi.Collector.NO_OP;
import static org.neo4j.driver.v1.AccessMode.READ;
import static org.neo4j.driver.v1.AccessMode.WRITE;
import static org.neo4j.driver.v1.Values.value;
/**
 * Unit tests for {@link NetworkSession}.
 * <p>
 * Every test runs against Mockito mocks of {@link ConnectionProvider} and {@link PooledConnection}, so the
 * assertions are about which calls the session makes on those collaborators (acquire, run, sync, close) and
 * about the session's own observable state (bookmark, open transaction).
 */
public class NetworkSessionTest
{
    @Rule
    public ExpectedException exception = ExpectedException.none();

    // Shared fixtures rebuilt by setUp(); several tests shadow them with their own local mocks.
    private PooledConnection connection;
    private NetworkSession session;

    /**
     * Creates a READ session whose provider hands out the shared {@link #connection} mock for any access mode.
     */
    @Before
    public void setUp() throws Exception
    {
        connection = mock( PooledConnection.class );
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        when( connectionProvider.acquireConnection( any( AccessMode.class ) ) ).thenReturn( connection );
        session = newSession( connectionProvider, READ );
    }

    @Test
    public void shouldSendAllOnRun() throws Throwable
    {
        // Given
        when( connection.isOpen() ).thenReturn( true );

        // When
        session.run( "whatever" );

        // Then
        verify( connection ).flush();
    }

    @Test
    public void shouldNotAllowNewTxWhileOneIsRunning() throws Throwable
    {
        // Given
        when( connection.isOpen() ).thenReturn( true );
        session.beginTransaction();

        // Expect
        exception.expect( ClientException.class );

        // When
        session.beginTransaction();
    }

    @Test
    public void shouldBeAbleToOpenTxAfterPreviousIsClosed() throws Throwable
    {
        // Given
        when( connection.isOpen() ).thenReturn( true );
        session.beginTransaction().close();

        // When
        Transaction tx = session.beginTransaction();

        // Then we should've gotten a transaction object back
        assertNotNull( tx );
    }

    @Test
    public void shouldNotBeAbleToUseSessionWhileOngoingTransaction() throws Throwable
    {
        // Given
        when( connection.isOpen() ).thenReturn( true );
        session.beginTransaction();

        // Expect
        exception.expect( ClientException.class );

        // When
        session.run( "whatever" );
    }

    @Test
    public void shouldBeAbleToUseSessionAgainWhenTransactionIsClosed() throws Throwable
    {
        // Given
        when( connection.isOpen() ).thenReturn( true );
        session.beginTransaction().close();

        // When
        session.run( "whatever" );

        // Then
        verify( connection ).flush();
    }

    /**
     * Closing a session a second time must fail with the "already closed" message.
     */
    @Test
    public void shouldGetExceptionIfTryingToCloseSessionMoreThanOnce() throws Throwable
    {
        // Given
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class, RETURNS_MOCKS );
        NetworkSession sess = newSession( connectionProvider, READ );

        // The first close must succeed. It is deliberately not wrapped in try/catch: if it throws, the
        // exception propagates and the test fails with the real cause and stack trace.
        sess.close();

        // When
        try
        {
            sess.close();
            fail( "Should have received an error to close second time" );
        }
        catch ( Exception e )
        {
            // fail() raises an AssertionError, which is not an Exception, so it is not swallowed here
            assertThat( e.getMessage(), equalTo( "This session has already been closed." ) );
        }
    }

    @Test
    public void runThrowsWhenSessionIsClosed()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class, RETURNS_MOCKS );
        NetworkSession session = newSession( connectionProvider, READ );

        session.close();

        try
        {
            session.run( "CREATE ()" );
            fail( "Exception expected" );
        }
        catch ( Exception e )
        {
            assertThat( e, instanceOf( ClientException.class ) );
            assertThat( e.getMessage(), containsString( "session is already closed" ) );
        }
    }

    @Test
    public void acquiresNewConnectionForRun()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );

        verify( connectionProvider ).acquireConnection( READ );
        verify( connection ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );
    }

    /**
     * A second run acquires a fresh connection; the previous one, reporting itself open, is synced and then
     * closed, in that order.
     */
    @Test
    public void syncsAndClosesPreviousConnectionForRun()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection1 = openConnectionMock();
        PooledConnection connection2 = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection1 ).thenReturn( connection2 );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );
        verify( connectionProvider ).acquireConnection( READ );
        verify( connection1 ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        session.run( "RETURN 2" );
        verify( connectionProvider, times( 2 ) ).acquireConnection( READ );
        verify( connection2 ).run( eq( "RETURN 2" ), anyParams(), any( Collector.class ) );

        InOrder inOrder = inOrder( connection1 );
        inOrder.verify( connection1 ).sync();
        inOrder.verify( connection1 ).close();
    }

    /**
     * Same as above, but the previous connection is a plain mock (isOpen() returns false), so it is closed
     * without being synced.
     */
    @Test
    public void closesPreviousBrokenConnectionForRun()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection1 = mock( PooledConnection.class );
        PooledConnection connection2 = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection1 ).thenReturn( connection2 );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );
        verify( connectionProvider ).acquireConnection( READ );
        verify( connection1 ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        session.run( "RETURN 2" );
        verify( connectionProvider, times( 2 ) ).acquireConnection( READ );
        verify( connection2 ).run( eq( "RETURN 2" ), anyParams(), any( Collector.class ) );

        verify( connection1, never() ).sync();
        verify( connection1 ).close();
    }

    @Test
    public void closesAndSyncOpenConnectionUsedForRunWhenSessionIsClosed()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );
        verify( connectionProvider ).acquireConnection( READ );
        verify( connection ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        session.close();

        verify( connection ).sync();
        verify( connection ).close();
    }

    @Test
    public void closesClosedConnectionUsedForRunWhenSessionIsClosed()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        // Plain mock: isOpen() returns false, so no sync is expected before close
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );
        verify( connectionProvider ).acquireConnection( READ );
        verify( connection ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        session.close();

        verify( connection, never() ).sync();
        verify( connection ).close();
    }

    @SuppressWarnings( "deprecation" )
    @Test
    public void resetDoesNothingWhenNoTransactionAndNoConnection()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        NetworkSession session = newSession( connectionProvider, READ );

        session.reset();

        verify( connectionProvider, never() ).acquireConnection( any( AccessMode.class ) );
    }

    @Test
    public void closeWithoutConnection()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        NetworkSession session = newSession( connectionProvider, READ );

        session.close();

        verify( connectionProvider, never() ).acquireConnection( any( AccessMode.class ) );
    }

    @Test
    public void acquiresNewConnectionForBeginTx()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.beginTransaction();

        verify( connectionProvider ).acquireConnection( READ );
    }

    @Test
    public void closesPreviousConnectionForBeginTx()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection1 = mock( PooledConnection.class );
        PooledConnection connection2 = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection1 ).thenReturn( connection2 );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );
        verify( connectionProvider ).acquireConnection( READ );
        verify( connection1 ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        session.beginTransaction();

        verify( connection1 ).close();
        verify( connectionProvider, times( 2 ) ).acquireConnection( READ );
    }

    /**
     * A bookmark set on the transaction only becomes the session's last bookmark once the transaction closes.
     */
    @Test
    public void updatesBookmarkWhenTxIsClosed()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        Transaction tx = session.beginTransaction();
        setBookmark( tx, "TheBookmark" );

        assertNull( session.lastBookmark() );

        tx.close();

        assertEquals( "TheBookmark", session.lastBookmark() );
    }

    @Test
    public void closesConnectionWhenTxIsClosed()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        Transaction tx = session.beginTransaction();
        tx.run( "RETURN 1" );

        verify( connectionProvider ).acquireConnection( READ );
        verify( connection ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        tx.close();

        verify( connection ).sync();
        verify( connection ).close();
    }

    /**
     * Closing an already-closed transaction again, after a later transaction has also been closed, must not
     * cause additional sync/close calls on either connection (verify() defaults to exactly one invocation).
     */
    @Test
    public void ignoresWronglyClosedTx()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection1 = openConnectionMock();
        PooledConnection connection2 = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection1 ).thenReturn( connection2 );
        NetworkSession session = newSession( connectionProvider, READ );

        Transaction tx1 = session.beginTransaction();
        tx1.close();

        Transaction tx2 = session.beginTransaction();
        tx2.close();

        tx1.close();

        verify( connection1 ).sync();
        verify( connection1 ).close();

        verify( connection2 ).sync();
        verify( connection2 ).close();
    }

    /**
     * Same as {@link #ignoresWronglyClosedTx()}, but the stale close happens while the second transaction is
     * still open.
     */
    @Test
    public void ignoresWronglyClosedTxWhenAnotherTxInProgress()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection1 = openConnectionMock();
        PooledConnection connection2 = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection1 ).thenReturn( connection2 );
        NetworkSession session = newSession( connectionProvider, READ );

        Transaction tx1 = session.beginTransaction();
        tx1.close();

        Transaction tx2 = session.beginTransaction();
        tx1.close();
        tx2.close();

        verify( connection1 ).sync();
        verify( connection1 ).close();

        verify( connection2 ).sync();
        verify( connection2 ).close();
    }

    @Test
    public void transactionClosedDoesNothingWhenNoTx()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.onTransactionClosed( mock( ExplicitTransaction.class ) );

        verifyZeroInteractions( connection );
    }

    @Test
    public void transactionClosedIgnoresWrongTx()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.beginTransaction();
        verify( connectionProvider ).acquireConnection( READ );

        // Notify the session about a transaction it never started
        ExplicitTransaction wrongTx = mock( ExplicitTransaction.class );
        session.onTransactionClosed( wrongTx );

        verify( connection, never() ).close();
    }

    @Test
    public void markTxAsFailedOnRecoverableConnectionError()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        Transaction tx = session.beginTransaction();
        assertTrue( tx.isOpen() );

        session.onConnectionError( true );

        assertFalse( tx.isOpen() );
    }

    @Test
    public void markTxToCloseOnUnrecoverableConnectionError()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        Transaction tx = session.beginTransaction();
        assertTrue( tx.isOpen() );

        session.onConnectionError( false );

        assertFalse( tx.isOpen() );
    }

    @Test
    public void closesConnectionWhenResultIsBuffered()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        session.run( "RETURN 1" );
        verify( connectionProvider ).acquireConnection( READ );
        verify( connection ).run( eq( "RETURN 1" ), anyParams(), any( Collector.class ) );

        session.onResultConsumed();

        verify( connection, never() ).sync();
        verify( connection ).close();
    }

    @Test
    public void bookmarkIsPropagatedFromSession()
    {
        String bookmark = "Bookmark";

        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ, bookmark );

        try ( Transaction ignore = session.beginTransaction() )
        {
            verifyBeginTx( connection, bookmark );
        }
    }

    @Test
    public void bookmarkIsPropagatedInBeginTransaction()
    {
        String bookmark = "Bookmark";

        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );
        session.setBookmark( bookmark );

        try ( Transaction ignore = session.beginTransaction() )
        {
            verifyBeginTx( connection, bookmark );
        }
    }

    /**
     * The bookmark left behind by one transaction is sent with the BEGIN of the next one, and the session's
     * last bookmark follows whatever the most recently closed transaction carried.
     */
    @Test
    public void bookmarkIsPropagatedBetweenTransactions()
    {
        String bookmark1 = "Bookmark1";
        String bookmark2 = "Bookmark2";

        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = mock( PooledConnection.class );
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );

        try ( Transaction tx = session.beginTransaction() )
        {
            setBookmark( tx, bookmark1 );
        }

        assertEquals( bookmark1, session.lastBookmark() );

        try ( Transaction tx = session.beginTransaction() )
        {
            verifyBeginTx( connection, bookmark1 );
            assertNull( getBookmark( tx ) );
            setBookmark( tx, bookmark2 );
        }

        assertEquals( bookmark2, session.lastBookmark() );
    }

    @Test
    public void accessModeUsedToAcquireConnections()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class, RETURNS_MOCKS );

        NetworkSession session1 = newSession( connectionProvider, READ );
        session1.beginTransaction();
        verify( connectionProvider ).acquireConnection( READ );

        NetworkSession session2 = newSession( connectionProvider, WRITE );
        session2.beginTransaction();
        verify( connectionProvider ).acquireConnection( WRITE );
    }

    @Test
    public void setLastBookmark()
    {
        NetworkSession session = newSession( mock( ConnectionProvider.class ), WRITE );

        session.setBookmark( "TheBookmark" );

        assertEquals( "TheBookmark", session.lastBookmark() );
    }

    @Test
    public void testPassingNoBookmarkShouldRetainBookmark()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );
        session.setBookmark( "X" );

        session.beginTransaction();

        assertThat( session.lastBookmark(), equalTo( "X" ) );
    }

    @SuppressWarnings( "deprecation" )
    @Test
    public void testPassingNullBookmarkShouldRetainBookmark()
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( READ ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, READ );
        session.setBookmark( "X" );

        session.beginTransaction( null );

        assertThat( session.lastBookmark(), equalTo( "X" ) );
    }

    // ---- Transaction functions: each test below delegates to a parameterised helper further down ----

    @Test
    public void acquiresReadConnectionForReadTxInReadSession()
    {
        testConnectionAcquisition( READ, READ );
    }

    @Test
    public void acquiresWriteConnectionForWriteTxInReadSession()
    {
        testConnectionAcquisition( READ, WRITE );
    }

    @Test
    public void acquiresReadConnectionForReadTxInWriteSession()
    {
        testConnectionAcquisition( WRITE, READ );
    }

    @Test
    public void acquiresWriteConnectionForWriteTxInWriteSession()
    {
        testConnectionAcquisition( WRITE, WRITE );
    }

    @Test
    public void commitsReadTxWhenMarkedSuccessful()
    {
        testTxCommitOrRollback( READ, true );
    }

    @Test
    public void commitsWriteTxWhenMarkedSuccessful()
    {
        testTxCommitOrRollback( WRITE, true );
    }

    // NOTE: despite the "WhenMarkedSuccessful" suffix, the two rollback tests below pass commit == false,
    // i.e. the work marks the transaction as failed. The names are kept unchanged for test-report continuity.

    @Test
    public void rollsBackReadTxWhenMarkedSuccessful()
    {
        testTxCommitOrRollback( READ, false );
    }

    @Test
    public void rollsBackWriteTxWhenMarkedSuccessful()
    {
        // Previously this called ( READ, true ), which duplicated commitsReadTxWhenMarkedSuccessful and left
        // the WRITE rollback path untested.
        testTxCommitOrRollback( WRITE, false );
    }

    @Test
    public void rollsBackReadTxWhenFunctionThrows()
    {
        testTxRollbackWhenThrows( READ );
    }

    @Test
    public void rollsBackWriteTxWhenFunctionThrows()
    {
        testTxRollbackWhenThrows( WRITE );
    }

    @Test
    public void readTxRetriedUntilSuccessWhenFunctionThrows()
    {
        testTxIsRetriedUntilSuccessWhenFunctionThrows( READ );
    }

    @Test
    public void writeTxRetriedUntilSuccessWhenFunctionThrows()
    {
        testTxIsRetriedUntilSuccessWhenFunctionThrows( WRITE );
    }

    @Test
    public void readTxRetriedUntilSuccessWhenTxCloseThrows()
    {
        testTxIsRetriedUntilSuccessWhenTxCloseThrows( READ );
    }

    @Test
    public void writeTxRetriedUntilSuccessWhenTxCloseThrows()
    {
        testTxIsRetriedUntilSuccessWhenTxCloseThrows( WRITE );
    }

    @Test
    public void readTxRetriedUntilFailureWhenFunctionThrows()
    {
        testTxIsRetriedUntilFailureWhenFunctionThrows( READ );
    }

    @Test
    public void writeTxRetriedUntilFailureWhenFunctionThrows()
    {
        testTxIsRetriedUntilFailureWhenFunctionThrows( WRITE );
    }

    @Test
    public void readTxRetriedUntilFailureWhenTxCloseThrows()
    {
        testTxIsRetriedUntilFailureWhenTxCloseThrows( READ );
    }

    @Test
    public void writeTxRetriedUntilFailureWhenTxCloseThrows()
    {
        testTxIsRetriedUntilFailureWhenTxCloseThrows( WRITE );
    }

    /**
     * Runs a successful transaction function and checks that the connection is acquired with the
     * transaction's access mode (not the session's), with exactly one BEGIN and one COMMIT.
     *
     * @param sessionMode access mode the session is created with
     * @param transactionMode access mode of the transaction function that is executed
     */
    private static void testConnectionAcquisition( AccessMode sessionMode, AccessMode transactionMode )
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( transactionMode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, sessionMode );

        TxWork work = new TxWork( 42 );

        int result = executeTransaction( session, transactionMode, work );

        verify( connectionProvider ).acquireConnection( transactionMode );
        verifyBeginTx( connection, times( 1 ) );
        verifyCommitTx( connection, times( 1 ) );
        assertEquals( 42, result );
    }

    /**
     * Runs a transaction function that marks the transaction as success or failure and checks that exactly
     * one COMMIT (and no ROLLBACK), or exactly one ROLLBACK (and no COMMIT), is sent accordingly.
     *
     * @param transactionMode access mode of the transaction function that is executed
     * @param commit {@code true} to call {@code tx.success()}, {@code false} to call {@code tx.failure()}
     */
    private static void testTxCommitOrRollback( AccessMode transactionMode, final boolean commit )
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( transactionMode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, WRITE );

        TransactionWork<Integer> work = new TransactionWork<Integer>()
        {
            @Override
            public Integer execute( Transaction tx )
            {
                if ( commit )
                {
                    tx.success();
                }
                else
                {
                    tx.failure();
                }
                return 4242;
            }
        };

        int result = executeTransaction( session, transactionMode, work );

        verify( connectionProvider ).acquireConnection( transactionMode );
        verifyBeginTx( connection, times( 1 ) );
        if ( commit )
        {
            verifyCommitTx( connection, times( 1 ) );
            verifyRollbackTx( connection, never() );
        }
        else
        {
            verifyRollbackTx( connection, times( 1 ) );
            verifyCommitTx( connection, never() );
        }
        assertEquals( 4242, result );
    }

    /**
     * Runs a transaction function that throws and checks that the very same exception reaches the caller
     * and that one BEGIN and one ROLLBACK were sent. The session uses {@code FixedRetryLogic( 0 )}.
     *
     * @param transactionMode access mode of the transaction function that is executed
     */
    private static void testTxRollbackWhenThrows( AccessMode transactionMode )
    {
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( transactionMode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, WRITE );

        final RuntimeException error = new IllegalStateException( "Oh!" );
        TransactionWork<Void> work = new TransactionWork<Void>()
        {
            @Override
            public Void execute( Transaction tx )
            {
                throw error;
            }
        };

        try
        {
            executeTransaction( session, transactionMode, work );
            fail( "Exception expected" );
        }
        catch ( Exception e )
        {
            assertEquals( error, e );
        }

        verify( connectionProvider ).acquireConnection( transactionMode );
        verifyBeginTx( connection, times( 1 ) );
        verifyRollbackTx( connection, times( 1 ) );
    }

    /**
     * The work throws {@code failures} times and then succeeds; with one more retry allowed than there are
     * failures, the result is returned after {@code failures + 1} invocations.
     */
    private static void testTxIsRetriedUntilSuccessWhenFunctionThrows( AccessMode mode )
    {
        int failures = 12;
        int retries = failures + 1;

        RetryLogic retryLogic = new FixedRetryLogic( retries );
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( mode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, retryLogic );

        TxWork work = spy( new TxWork( 42, failures, new SessionExpiredException( "" ) ) );
        int answer = executeTransaction( session, mode, work );

        assertEquals( 42, answer );
        verifyInvocationCount( work, failures + 1 );
        verifyCommitTx( connection, times( 1 ) );
        verifyRollbackTx( connection, times( failures ) );
    }

    /**
     * The work always succeeds but COMMIT throws {@code failures} times; with one more retry allowed than
     * there are failures, the result is returned after {@code failures + 1} invocations.
     */
    private static void testTxIsRetriedUntilSuccessWhenTxCloseThrows( AccessMode mode )
    {
        int failures = 13;
        int retries = failures + 1;

        RetryLogic retryLogic = new FixedRetryLogic( retries );
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = connectionWithFailingCommit( failures );
        when( connectionProvider.acquireConnection( mode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, retryLogic );

        TxWork work = spy( new TxWork( 43 ) );
        int answer = executeTransaction( session, mode, work );

        assertEquals( 43, answer );
        verifyInvocationCount( work, failures + 1 );
        verifyCommitTx( connection, times( retries ) );
        verifyRollbackTx( connection, times( failures ) );
    }

    /**
     * The work would throw {@code failures} times, but only {@code failures - 1} retries are configured, so
     * the last {@link SessionExpiredException} reaches the caller after {@code failures} invocations and no
     * COMMIT is ever sent.
     */
    private static void testTxIsRetriedUntilFailureWhenFunctionThrows( AccessMode mode )
    {
        int failures = 14;
        int retries = failures - 1;

        RetryLogic retryLogic = new FixedRetryLogic( retries );
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = openConnectionMock();
        when( connectionProvider.acquireConnection( mode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, retryLogic );

        TxWork work = spy( new TxWork( 42, failures, new SessionExpiredException( "Oh!" ) ) );
        try
        {
            executeTransaction( session, mode, work );
            fail( "Exception expected" );
        }
        catch ( Exception e )
        {
            assertThat( e, instanceOf( SessionExpiredException.class ) );
            assertEquals( "Oh!", e.getMessage() );
            verifyInvocationCount( work, failures );
            verifyCommitTx( connection, never() );
            verifyRollbackTx( connection, times( failures ) );
        }
    }

    /**
     * COMMIT would throw {@code failures} times, but only {@code failures - 1} retries are configured, so
     * the {@link ServiceUnavailableException} reaches the caller after {@code failures} invocations.
     */
    private static void testTxIsRetriedUntilFailureWhenTxCloseThrows( AccessMode mode )
    {
        int failures = 17;
        int retries = failures - 1;

        RetryLogic retryLogic = new FixedRetryLogic( retries );
        ConnectionProvider connectionProvider = mock( ConnectionProvider.class );
        PooledConnection connection = connectionWithFailingCommit( failures );
        when( connectionProvider.acquireConnection( mode ) ).thenReturn( connection );
        NetworkSession session = newSession( connectionProvider, retryLogic );

        TxWork work = spy( new TxWork( 42 ) );
        try
        {
            executeTransaction( session, mode, work );
            fail( "Exception expected" );
        }
        catch ( Exception e )
        {
            assertThat( e, instanceOf( ServiceUnavailableException.class ) );
            verifyInvocationCount( work, failures );
            verifyCommitTx( connection, times( failures ) );
            verifyRollbackTx( connection, times( failures ) );
        }
    }

    /**
     * Dispatches to {@code readTransaction} or {@code writeTransaction} depending on {@code mode}.
     *
     * @return whatever the session returns for the given work
     * @throws IllegalArgumentException if {@code mode} is neither READ nor WRITE
     */
    private static <T> T executeTransaction( Session session, AccessMode mode, TransactionWork<T> work )
    {
        if ( mode == READ )
        {
            return session.readTransaction( work );
        }
        else if ( mode == WRITE )
        {
            return session.writeTransaction( work );
        }
        else
        {
            throw new IllegalArgumentException( "Unknown mode " + mode );
        }
    }

    /** Creates a session with the given mode, {@code FixedRetryLogic( 0 )} and no bookmark. */
    private static NetworkSession newSession( ConnectionProvider connectionProvider, AccessMode mode )
    {
        return newSession( connectionProvider, mode, null );
    }

    /** Creates a WRITE session with the given retry logic and no bookmark. */
    private static NetworkSession newSession( ConnectionProvider connectionProvider, RetryLogic retryLogic )
    {
        return newSession( connectionProvider, WRITE, retryLogic, null );
    }

    /** Creates a session with the given mode and initial bookmark, using {@code FixedRetryLogic( 0 )}. */
    private static NetworkSession newSession( ConnectionProvider connectionProvider, AccessMode mode, String bookmark )
    {
        return newSession( connectionProvider, mode, new FixedRetryLogic( 0 ), bookmark );
    }

    /**
     * Creates a session with logging discarded and applies {@code bookmark} through
     * {@code setBookmark}; {@code bookmark} may be {@code null}.
     */
    private static NetworkSession newSession( ConnectionProvider connectionProvider, AccessMode mode,
            RetryLogic retryLogic, String bookmark )
    {
        NetworkSession session = new NetworkSession( connectionProvider, mode, retryLogic, DEV_NULL_LOGGING );
        session.setBookmark( bookmark );
        return session;
    }

    /** Returns a connection mock whose {@code isOpen()} reports {@code true}. */
    private static PooledConnection openConnectionMock()
    {
        PooledConnection connection = mock( PooledConnection.class );
        when( connection.isOpen() ).thenReturn( true );
        return connection;
    }

    /**
     * Returns an open connection mock whose first {@code times} runs of the "COMMIT" statement throw
     * {@link ServiceUnavailableException}; later COMMITs complete normally.
     */
    private static PooledConnection connectionWithFailingCommit( final int times )
    {
        PooledConnection connection = openConnectionMock();

        doAnswer( new Answer<Void>()
        {
            // Number of COMMIT invocations seen so far by this answer
            int invoked;

            @Override
            public Void answer( InvocationOnMock invocation ) throws Throwable
            {
                if ( invoked++ < times )
                {
                    throw new ServiceUnavailableException( "" );
                }
                return null;
            }
        } ).when( connection ).run( eq( "COMMIT" ), anyParams(), any( Collector.class ) );

        return connection;
    }

    /** Verifies that the spied work was executed exactly {@code expectedInvocationCount} times. */
    private static void verifyInvocationCount( TransactionWork<?> workSpy, int expectedInvocationCount )
    {
        verify( workSpy, times( expectedInvocationCount ) ).execute( any( Transaction.class ) );
    }

    /** Verifies runs of the "BEGIN" statement (any parameters, any collector) under the given mode. */
    private static void verifyBeginTx( PooledConnection connectionMock, VerificationMode mode )
    {
        verifyRun( connectionMock, "BEGIN", mode );
    }

    /** Verifies a single "BEGIN" run carrying exactly the given bookmark parameter and the NO_OP collector. */
    private static void verifyBeginTx( PooledConnection connectionMock, String bookmark )
    {
        verify( connectionMock ).run( "BEGIN", singletonMap( "bookmark", value( bookmark ) ), NO_OP );
    }

    /** Verifies runs of the "COMMIT" statement (any parameters, any collector) under the given mode. */
    private static void verifyCommitTx( PooledConnection connectionMock, VerificationMode mode )
    {
        verifyRun( connectionMock, "COMMIT", mode );
    }

    /** Verifies runs of the "ROLLBACK" statement (any parameters, any collector) under the given mode. */
    private static void verifyRollbackTx( PooledConnection connectionMock, VerificationMode mode )
    {
        verifyRun( connectionMock, "ROLLBACK", mode );
    }

    /** Verifies runs of {@code statement} with any parameters and any collector under the given mode. */
    private static void verifyRun( PooledConnection connectionMock, String statement, VerificationMode mode )
    {
        verify( connectionMock, mode ).run( eq( statement ), anyParams(), any( Collector.class ) );
    }

    /** Mockito matcher accepting any {@code Map<String,Value>} of statement parameters. */
    private static Map<String,Value> anyParams()
    {
        return anyMapOf( String.class, Value.class );
    }

    /** Reads the bookmark from a transaction; the transaction must be an {@link ExplicitTransaction}. */
    private static String getBookmark( Transaction tx )
    {
        return ((ExplicitTransaction) tx).bookmark();
    }

    /** Sets the bookmark on a transaction; the transaction must be an {@link ExplicitTransaction}. */
    private static void setBookmark( Transaction tx, String bookmark )
    {
        ((ExplicitTransaction) tx).setBookmark( bookmark );
    }

    /**
     * Transaction function that throws a supplied error for its first {@code timesToThrow} invocations and
     * afterwards marks the transaction successful and returns a fixed result.
     */
    private static class TxWork implements TransactionWork<Integer>
    {
        final int result;
        final int timesToThrow;
        final Supplier<RuntimeException> errorSupplier;

        // Counts invocations; only advanced while timesToThrow > 0 (see execute)
        int invoked;

        /** Work that never throws and returns {@code result}. */
        TxWork( int result )
        {
            // The typed cast picks the Supplier overload without resorting to a raw type
            this( result, 0, (Supplier<RuntimeException>) null );
        }

        /** Work that throws the same {@code error} instance {@code timesToThrow} times before succeeding. */
        TxWork( int result, int timesToThrow, final RuntimeException error )
        {
            this.result = result;
            this.timesToThrow = timesToThrow;
            this.errorSupplier = new Supplier<RuntimeException>()
            {
                @Override
                public RuntimeException get()
                {
                    return error;
                }
            };
        }

        /** Work that throws {@code errorSupplier.get()} {@code timesToThrow} times before succeeding. */
        TxWork( int result, int timesToThrow, Supplier<RuntimeException> errorSupplier )
        {
            this.result = result;
            this.timesToThrow = timesToThrow;
            this.errorSupplier = errorSupplier;
        }

        @Override
        public Integer execute( Transaction tx )
        {
            // The timesToThrow > 0 guard keeps a null errorSupplier from ever being dereferenced
            if ( timesToThrow > 0 && invoked++ < timesToThrow )
            {
                throw errorSupplier.get();
            }
            tx.success();
            return result;
        }
    }
}