/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal.cluster;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.net.pooling.PoolSettings;
import org.neo4j.driver.internal.net.pooling.SocketConnectionPool;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.internal.spi.Connector;
import org.neo4j.driver.internal.spi.PooledConnection;
import org.neo4j.driver.internal.util.Clock;
import org.neo4j.driver.v1.exceptions.ClientException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import org.neo4j.driver.v1.exceptions.SessionExpiredException;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonMap;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER;
import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING;
import static org.neo4j.driver.internal.net.pooling.PoolSettings.DEFAULT_MAX_IDLE_CONNECTION_POOL_SIZE;
import static org.neo4j.driver.internal.net.pooling.PoolSettings.NO_IDLE_CONNECTION_TEST;
import static org.neo4j.driver.internal.spi.Collector.NO_OP;
import static org.neo4j.driver.internal.util.Matchers.containsReader;
import static org.neo4j.driver.internal.util.Matchers.containsRouter;
import static org.neo4j.driver.internal.util.Matchers.containsWriter;
import static org.neo4j.driver.v1.AccessMode.READ;
import static org.neo4j.driver.v1.AccessMode.WRITE;
import static org.neo4j.driver.v1.Values.value;
@RunWith( Parameterized.class )
public class RoutingPooledConnectionErrorHandlingTest
{
private static final BoltServerAddress ADDRESS1 = new BoltServerAddress( "server-1", 26000 );
private static final BoltServerAddress ADDRESS2 = new BoltServerAddress( "server-2", 27000 );
private static final BoltServerAddress ADDRESS3 = new BoltServerAddress( "server-3", 28000 );
@Parameter
public ConnectionMethod method;
@Parameters( name = "{0}" )
public static List<ConnectionMethod> methods()
{
return asList(
new Init(),
new Run(),
new DiscardAll(),
new PullAll(),
new Reset(),
new ResetAsync(),
new AckFailure(),
new Sync(),
new Flush(),
new ReceiveOne() );
}
@Test
public void shouldHandleServiceUnavailableException()
{
ServiceUnavailableException serviceUnavailable = new ServiceUnavailableException( "Oh!" );
Connector connector = newConnectorWithThrowingConnections( serviceUnavailable );
ClusterComposition clusterComposition = newClusterComposition( ADDRESS1, ADDRESS2, ADDRESS3 );
RoutingTable routingTable = newRoutingTable( clusterComposition );
ConnectionPool connectionPool = newConnectionPool( connector, ADDRESS1, ADDRESS2, ADDRESS3 );
LoadBalancer loadBalancer = newLoadBalancer( clusterComposition, routingTable, connectionPool );
Connection readConnection = loadBalancer.acquireConnection( READ );
verifyServiceUnavailableHandling( readConnection, routingTable, connectionPool );
Connection writeConnection = loadBalancer.acquireConnection( WRITE );
verifyServiceUnavailableHandling( writeConnection, routingTable, connectionPool );
assertThat( routingTable, containsRouter( ADDRESS3 ) );
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
}
@Test
public void shouldHandleFailureToWriteWithWriteConnection()
{
testHandleFailureToWriteWithWriteConnection( new ClientException( "Neo.ClientError.Cluster.NotALeader", "" ) );
testHandleFailureToWriteWithWriteConnection(
new ClientException( "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase", "" ) );
}
@Test
public void shouldHandleFailureToWrite()
{
testHandleFailureToWrite( new ClientException( "Neo.ClientError.Cluster.NotALeader", "" ) );
testHandleFailureToWrite( new ClientException( "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase", "" ) );
}
@Test
public void shouldPropagateThrowable()
{
testThrowablePropagation( new RuntimeException( "Random error" ) );
}
@Test
public void shouldPropagateClientExceptionWithoutErrorCode()
{
testThrowablePropagation( new ClientException( null, "Message" ) );
}
private void testHandleFailureToWriteWithWriteConnection( ClientException error )
{
Connector connector = newConnectorWithThrowingConnections( error );
ClusterComposition clusterComposition = newClusterComposition( ADDRESS1, ADDRESS2, ADDRESS3 );
RoutingTable routingTable = newRoutingTable( clusterComposition );
ConnectionPool connectionPool = newConnectionPool( connector, ADDRESS1, ADDRESS2, ADDRESS3 );
LoadBalancer loadBalancer = newLoadBalancer( clusterComposition, routingTable, connectionPool );
Connection readConnection = loadBalancer.acquireConnection( READ );
try
{
method.invoke( readConnection );
fail( "Exception expected" );
}
catch ( Exception e )
{
assertThat( e, instanceOf( ClientException.class ) );
BoltServerAddress address = readConnection.boltServerAddress();
assertThat( routingTable, containsRouter( address ) );
assertThat( routingTable, containsReader( address ) );
assertThat( routingTable, containsWriter( address ) );
assertTrue( connectionPool.hasAddress( address ) );
}
assertThat( routingTable, containsRouter( ADDRESS3 ) );
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
}
private void testHandleFailureToWrite( ClientException error )
{
Connector connector = newConnectorWithThrowingConnections( error );
ClusterComposition clusterComposition = newClusterComposition( ADDRESS1, ADDRESS2, ADDRESS3 );
RoutingTable routingTable = newRoutingTable( clusterComposition );
ConnectionPool connectionPool = newConnectionPool( connector, ADDRESS1, ADDRESS2, ADDRESS3 );
LoadBalancer loadBalancer = newLoadBalancer( clusterComposition, routingTable, connectionPool );
Connection readConnection = loadBalancer.acquireConnection( WRITE );
try
{
method.invoke( readConnection );
fail( "Exception expected" );
}
catch ( Exception e )
{
assertThat( e, instanceOf( SessionExpiredException.class ) );
BoltServerAddress address = readConnection.boltServerAddress();
assertThat( routingTable, containsRouter( address ) );
assertThat( routingTable, containsReader( address ) );
assertThat( routingTable, not( containsWriter( address ) ) );
assertTrue( connectionPool.hasAddress( address ) );
}
assertThat( routingTable, containsRouter( ADDRESS3 ) );
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
}
private void testThrowablePropagation( Throwable error )
{
Connector connector = newConnectorWithThrowingConnections( error );
ClusterComposition clusterComposition = newClusterComposition( ADDRESS1, ADDRESS2, ADDRESS3 );
RoutingTable routingTable = newRoutingTable( clusterComposition );
ConnectionPool connectionPool = newConnectionPool( connector, ADDRESS1, ADDRESS2, ADDRESS3 );
LoadBalancer loadBalancer = newLoadBalancer( clusterComposition, routingTable, connectionPool );
Connection readConnection = loadBalancer.acquireConnection( READ );
verifyThrowablePropagation( readConnection, routingTable, connectionPool, error.getClass() );
Connection writeConnection = loadBalancer.acquireConnection( WRITE );
verifyThrowablePropagation( writeConnection, routingTable, connectionPool, error.getClass() );
assertThat( routingTable, containsRouter( ADDRESS3 ) );
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
}
private void verifyServiceUnavailableHandling( Connection connection, RoutingTable routingTable,
ConnectionPool connectionPool )
{
try
{
method.invoke( connection );
fail( "Exception expected" );
}
catch ( Exception e )
{
assertThat( e, instanceOf( SessionExpiredException.class ) );
assertThat( e.getCause(), instanceOf( ServiceUnavailableException.class ) );
BoltServerAddress address = connection.boltServerAddress();
assertThat( routingTable, not( containsRouter( address ) ) );
assertThat( routingTable, not( containsReader( address ) ) );
assertThat( routingTable, not( containsWriter( address ) ) );
assertFalse( connectionPool.hasAddress( address ) );
}
}
private <T extends Throwable> void verifyThrowablePropagation( Connection connection, RoutingTable routingTable,
ConnectionPool connectionPool, Class<T> expectedClass )
{
try
{
method.invoke( connection );
fail( "Exception expected" );
}
catch ( Exception e )
{
assertThat( e, instanceOf( expectedClass ) );
BoltServerAddress address = connection.boltServerAddress();
assertThat( routingTable, containsRouter( address ) );
assertThat( routingTable, containsReader( address ) );
assertThat( routingTable, containsWriter( address ) );
assertTrue( connectionPool.hasAddress( address ) );
}
}
private Connector newConnectorWithThrowingConnections( final Throwable error )
{
Connector connector = mock( Connector.class );
when( connector.connect( any( BoltServerAddress.class ) ) ).thenAnswer( new Answer<Connection>()
{
@Override
public Connection answer( InvocationOnMock invocation ) throws Throwable
{
BoltServerAddress address = invocation.getArgumentAt( 0, BoltServerAddress.class );
Connection connection = newConnectionMock( address );
method.invoke( doThrow( error ).doNothing().when( connection ) );
return connection;
}
} );
return connector;
}
private static Connection newConnectionMock( BoltServerAddress address )
{
Connection connection = mock( Connection.class );
when( connection.boltServerAddress() ).thenReturn( address );
return connection;
}
private static ClusterComposition newClusterComposition( BoltServerAddress... addresses )
{
return new ClusterComposition(
Long.MAX_VALUE,
new HashSet<>( asList( addresses ) ),
new HashSet<>( asList( addresses ) ),
new HashSet<>( asList( addresses ) ) );
}
private static RoutingTable newRoutingTable( ClusterComposition clusterComposition )
{
RoutingTable routingTable = new ClusterRoutingTable( Clock.SYSTEM );
routingTable.update( clusterComposition );
return routingTable;
}
private static ConnectionPool newConnectionPool( Connector connector, BoltServerAddress... addresses )
{
int maxIdleConnections = DEFAULT_MAX_IDLE_CONNECTION_POOL_SIZE;
PoolSettings settings = new PoolSettings( maxIdleConnections, NO_IDLE_CONNECTION_TEST );
SocketConnectionPool pool = new SocketConnectionPool( settings, connector, Clock.SYSTEM, DEV_NULL_LOGGING );
// force pool to create and memorize some connections
for ( BoltServerAddress address : addresses )
{
List<PooledConnection> connections = new ArrayList<>();
for ( int i = 0; i < maxIdleConnections; i++ )
{
connections.add( pool.acquire( address ) );
}
for ( PooledConnection connection : connections )
{
connection.close();
}
}
return pool;
}
private static LoadBalancer newLoadBalancer( ClusterComposition clusterComposition, RoutingTable routingTable,
ConnectionPool connectionPool )
{
Rediscovery rediscovery = mock( Rediscovery.class );
when( rediscovery.lookupClusterComposition( routingTable, connectionPool ) ).thenReturn( clusterComposition );
return new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGER );
}
private interface ConnectionMethod
{
void invoke( Connection connection );
}
private static class Init implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.init( "JavaDriver", singletonMap( "Key", value( "Value" ) ) );
}
}
private static class Run implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.run( "CREATE (n:Node {name: {value}})", singletonMap( "value", value( "A" ) ), NO_OP );
}
}
private static class DiscardAll implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.discardAll( NO_OP );
}
}
private static class PullAll implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.pullAll( NO_OP );
}
}
private static class Reset implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.reset();
}
}
private static class ResetAsync implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.resetAsync();
}
}
private static class AckFailure implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.ackFailure();
}
}
private static class Sync implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.sync();
}
}
private static class Flush implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.flush();
}
}
private static class ReceiveOne implements ConnectionMethod
{
@Override
public void invoke( Connection connection )
{
connection.receiveOne();
}
}
}