/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.internal.stubbing.answers.ThrowsException;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.File;
import java.util.Collections;
import java.util.Map;
import org.neo4j.driver.internal.cluster.LoadBalancer;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.retry.FixedRetryLogic;
import org.neo4j.driver.internal.spi.Collector;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.internal.spi.ConnectionProvider;
import org.neo4j.driver.internal.spi.PooledConnection;
import org.neo4j.driver.internal.summary.InternalServerInfo;
import org.neo4j.driver.internal.util.FakeClock;
import org.neo4j.driver.v1.AccessMode;
import org.neo4j.driver.v1.Config;
import org.neo4j.driver.v1.Driver;
import org.neo4j.driver.v1.EventLogger;
import org.neo4j.driver.v1.GraphDatabase;
import org.neo4j.driver.v1.Logging;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.Value;
import org.neo4j.driver.v1.exceptions.ClientException;
import org.neo4j.driver.v1.exceptions.ProtocolException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import static java.util.Arrays.asList;
import static junit.framework.TestCase.fail;
import static org.hamcrest.Matchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.serverInfo;
import static org.neo4j.driver.internal.security.SecurityPlan.insecure;
import static org.neo4j.driver.v1.Values.value;
public class RoutingDriverTest
{
// JUnit rule for declaring expected exceptions; initialised to expect none.
@Rule
public ExpectedException exception = ExpectedException.none();
// Address passed to every LoadBalancer created by driverWithPool(...).
private static final BoltServerAddress SEED = new BoltServerAddress( "localhost", 7687 );
// Statement that the mocked connections in pool(...) are stubbed to respond to.
private static final String GET_SERVERS = "CALL dbms.cluster.routing.getServers";
// Event handler shared by the fake clock and the logging provider below.
private final EventHandler events = new EventHandler();
// Clock that the tests advance explicitly through clock.progress(...).
private final FakeClock clock = new FakeClock( events, true );
// Logging provider built on the event handler at TRACE level.
private final Logging logging = EventLogger.provider( events, EventLogger.Level.TRACE );
/**
 * Creating a driver on top of the pool is expected to acquire a connection to the {@code SEED} address.
 */
@Test
public void shouldDiscoveryOnInitialization()
{
    // Given a pool whose connections answer with one router, one reader and one writer
    Map<String,Object> routers = serverInfo( "ROUTE", "localhost:1111" );
    Map<String,Object> readers = serverInfo( "READ", "localhost:2222" );
    Map<String,Object> writers = serverInfo( "WRITE", "localhost:3333" );
    ConnectionPool connectionPool = poolWithServers( 10, routers, readers, writers );

    // When the driver is created
    driverWithPool( connectionPool );

    // Then the seed address has been asked for a connection
    verify( connectionPool ).acquire( SEED );
}
/**
 * When the first routing answer carries an empty WRITE entry, a WRITE session is expected to be bound to
 * the writer listed in the second answer.
 */
@Test
public void shouldRediscoveryIfNoWritersProvided()
{
    // Given: the first answer has no writer addresses, the second one lists a writer on port 3333
    Answer withoutWriters = withServers( 10,
            serverInfo( "ROUTE", "localhost:1111" ),
            serverInfo( "WRITE" ),
            serverInfo( "READ", "localhost:5555" ) );
    Answer withWriter = withServers( 10,
            serverInfo( "ROUTE", "localhost:1112" ),
            serverInfo( "READ", "localhost:2222" ),
            serverInfo( "WRITE", "localhost:3333" ) );
    Driver driver = driverWithPool( pool( withoutWriters, withWriter ) );

    // When
    Session session = driver.session( AccessMode.WRITE );
    NetworkSessionWithAddress writeSession = (NetworkSessionWithAddress) session;

    // Then
    assertEquals( boltAddress( "localhost", 3333 ), writeSession.address );
}
/**
 * With a complete first routing answer, sessions are expected to use the servers from that answer and
 * not the ones (all on port 5555) offered by the second answer.
 */
@Test
public void shouldNotRediscoveryOnSessionAcquisitionIfNotNecessary()
{
    // Given: the first answer is complete; the second answer would route everything to port 5555
    Answer initialServers = withServers( 10,
            serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ),
            serverInfo( "READ", "localhost:2222" ),
            serverInfo( "WRITE", "localhost:3333" ) );
    Answer laterServers = withServers( 10,
            serverInfo( "ROUTE", "localhost:5555" ),
            serverInfo( "READ", "localhost:5555" ),
            serverInfo( "WRITE", "localhost:5555" ) );
    Driver driver = driverWithPool( pool( initialServers, laterServers ) );

    // When
    Session writeSession = driver.session( AccessMode.WRITE );
    Session readSession = driver.session( AccessMode.READ );

    // Then both sessions point at servers from the first answer
    assertEquals( boltAddress( "localhost", 3333 ), ((NetworkSessionWithAddress) writeSession).address );
    assertEquals( boltAddress( "localhost", 2222 ), ((NetworkSessionWithAddress) readSession).address );
}
/**
 * Driver creation is expected to fail with a {@link ServiceUnavailableException} when fetching the
 * routing information ends in a "procedure not found" {@link ClientException}.
 */
@Test
public void shouldFailIfNoRouting()
{
    // Given a pool whose connections throw a ClientException when pullAll is invoked
    ConnectionPool pool = pool( new ThrowsException( new ClientException(
            "Neo.ClientError.Procedure.ProcedureNotFound", "Procedure not found" ) ) );
    // When
    try
    {
        driverWithPool( pool );
        // Without this call the test would pass silently if no exception were thrown at all
        fail( "Expected a ServiceUnavailableException" );
    }
    // Then
    catch ( ServiceUnavailableException e )
    {
        assertThat( e.getMessage(),
                containsString( "Failed to run 'CALL dbms.cluster.routing.getServers {}' on server." ) );
    }
}
/**
 * Driver creation is expected to fail with a {@link ProtocolException} when the routing answer carries
 * a ROUTE entry without any addresses.
 */
@Test
public void shouldFailIfNoRoutersProvided()
{
    // Given a routing answer whose ROUTE entry lists no addresses
    ConnectionPool pool = poolWithServers(
            10,
            serverInfo( "ROUTE" ),
            serverInfo( "READ", "localhost:1111" ),
            serverInfo( "WRITE", "localhost:1111" ) );
    // When
    try
    {
        driverWithPool( pool );
        // Without this call the test would pass silently if no exception were thrown at all
        fail( "Expected a ProtocolException" );
    }
    // Then
    catch ( ProtocolException e )
    {
        assertThat( e.getMessage(), containsString( "no router or reader found in response" ) );
    }
}
/**
 * Driver creation is expected to fail with a {@link ProtocolException} when the routing answer carries
 * a READ entry without any addresses.
 */
@Test
public void shouldFailIfNoReaderProvided()
{
    // Given a routing answer whose READ entry lists no addresses
    ConnectionPool pool = poolWithServers(
            10,
            serverInfo( "READ" ),
            serverInfo( "ROUTE", "localhost:1111" ),
            serverInfo( "WRITE", "localhost:1111" ) );
    // When
    try
    {
        driverWithPool( pool );
        // Without this call the test would pass silently if no exception were thrown at all
        fail( "Expected a ProtocolException" );
    }
    // Then
    catch ( ProtocolException e )
    {
        assertThat( e.getMessage(), containsString( "no router or reader found in response" ) );
    }
}
/**
 * Two consecutive WRITE sessions are both expected to be bound to the writer from the second routing
 * answer, given that the first answer carries no writer addresses.
 */
@Test
public void shouldForgetServersOnRediscovery()
{
    // Given: the first answer has an empty WRITE entry, the second lists a writer on port 3333
    Answer withoutWriters = withServers( 10,
            serverInfo( "ROUTE", "localhost:1111" ),
            serverInfo( "READ", "localhost:5555" ),
            serverInfo( "WRITE" ) );
    Answer withWriter = withServers( 10,
            serverInfo( "ROUTE", "localhost:1112" ),
            serverInfo( "READ", "localhost:2222" ),
            serverInfo( "WRITE", "localhost:3333" ) );
    ConnectionPool connectionPool = pool( withoutWriters, withWriter );
    Driver driver = driverWithPool( connectionPool );

    // When two write sessions are requested one after the other
    NetworkSessionWithAddress[] writes = new NetworkSessionWithAddress[2];
    for ( int i = 0; i < writes.length; i++ )
    {
        writes[i] = (NetworkSessionWithAddress) driver.session( AccessMode.WRITE );
    }

    // Then
    assertEquals( boltAddress( "localhost", 3333 ), writes[0].address );
    assertEquals( boltAddress( "localhost", 3333 ), writes[1].address );
}
/**
 * After the fake clock has been advanced by 11_000, sessions are expected to use the servers from the
 * second routing answer (the first answer was handed out with a ttl of 10).
 */
@Test
public void shouldRediscoverOnTimeout()
{
    // Given
    Answer shortLived = withServers( 10,
            serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ),
            serverInfo( "READ", "localhost:2222" ),
            serverInfo( "WRITE", "localhost:3333" ) );
    Answer refreshed = withServers( 60,
            serverInfo( "ROUTE", "localhost:5555", "localhost:6666" ),
            serverInfo( "READ", "localhost:7777" ),
            serverInfo( "WRITE", "localhost:8888" ) );
    Driver driver = driverWithPool( pool( shortLived, refreshed ) );
    // NOTE(review): presumably 11_000 clock units exceed the ttl of 10 - confirm the units used by LoadBalancer
    clock.progress( 11_000 );

    // When
    Session writeSession = driver.session( AccessMode.WRITE );
    Session readSession = driver.session( AccessMode.READ );

    // Then both sessions point at servers from the second answer
    assertEquals( boltAddress( "localhost", 8888 ), ((NetworkSessionWithAddress) writeSession).address );
    assertEquals( boltAddress( "localhost", 7777 ), ((NetworkSessionWithAddress) readSession).address );
}
/**
 * After the fake clock has been advanced by only 9900, sessions are expected to keep using the servers
 * from the first routing answer (handed out with a ttl of 10).
 */
@Test
public void shouldNotRediscoverWhenNoTimeout()
{
    // Given
    Answer initialServers = withServers( 10,
            serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ),
            serverInfo( "READ", "localhost:2222" ),
            serverInfo( "WRITE", "localhost:3333" ) );
    Answer laterServers = withServers( 10,
            serverInfo( "ROUTE", "localhost:5555" ),
            serverInfo( "READ", "localhost:5555" ),
            serverInfo( "WRITE", "localhost:5555" ) );
    Driver driver = driverWithPool( pool( initialServers, laterServers ) );
    // NOTE(review): presumably 9900 clock units stay below the ttl of 10 - confirm the units used by LoadBalancer
    clock.progress( 9900 );

    // When
    Session writeSession = driver.session( AccessMode.WRITE );
    Session readSession = driver.session( AccessMode.READ );

    // Then both sessions still point at servers from the first answer
    assertEquals( boltAddress( "localhost", 2222 ), ((NetworkSessionWithAddress) readSession).address );
    assertEquals( boltAddress( "localhost", 3333 ), ((NetworkSessionWithAddress) writeSession).address );
}
/**
 * Six consecutive READ sessions over three readers are expected to cycle through the readers: sessions
 * three apart share an address, while the first three sessions all differ.
 */
@Test
public void shouldRoundRobinAmongReadServers()
{
    // Given a routing answer with three readers
    Driver driver = driverWithServers( 60,
            serverInfo( "ROUTE", "localhost:1111", "localhost:1112" ),
            serverInfo( "READ", "localhost:2222", "localhost:2223", "localhost:2224" ),
            serverInfo( "WRITE", "localhost:3333" ) );

    // When six read sessions are requested in a row
    NetworkSessionWithAddress[] reads = new NetworkSessionWithAddress[6];
    for ( int i = 0; i < reads.length; i++ )
    {
        reads[i] = (NetworkSessionWithAddress) driver.session( AccessMode.READ );
    }

    // Then the second round repeats the first one...
    assertEquals( reads[0].address, reads[3].address );
    assertEquals( reads[1].address, reads[4].address );
    assertEquals( reads[2].address, reads[5].address );
    // ...and the addresses within one round are pairwise different
    assertNotEquals( reads[0].address, reads[1].address );
    assertNotEquals( reads[1].address, reads[2].address );
    assertNotEquals( reads[2].address, reads[0].address );
}
/**
 * Six consecutive WRITE sessions over three writers are expected to cycle through the writers: sessions
 * three apart share an address, while the first three sessions all differ.
 */
@Test
public void shouldRoundRobinAmongWriteServers()
{
    // Given a routing answer with three writers
    Driver driver = driverWithServers( 60,
            serverInfo( "ROUTE", "localhost:1111", "localhost:1112" ),
            serverInfo( "READ", "localhost:3333" ),
            serverInfo( "WRITE", "localhost:2222", "localhost:2223", "localhost:2224" ) );

    // When six write sessions are requested in a row
    NetworkSessionWithAddress[] writes = new NetworkSessionWithAddress[6];
    for ( int i = 0; i < writes.length; i++ )
    {
        writes[i] = (NetworkSessionWithAddress) driver.session( AccessMode.WRITE );
    }

    // Then the second round repeats the first one...
    assertEquals( writes[0].address, writes[3].address );
    assertEquals( writes[1].address, writes[4].address );
    assertEquals( writes[2].address, writes[5].address );
    // ...and the addresses within one round are pairwise different
    assertNotEquals( writes[0].address, writes[1].address );
    assertNotEquals( writes[1].address, writes[2].address );
    assertNotEquals( writes[2].address, writes[0].address );
}
/**
 * Creating a {@code bolt+routing} driver with a trust-on-first-use configuration is expected to be
 * rejected with an {@link IllegalArgumentException}.
 */
@SuppressWarnings( "deprecation" )
@Test
public void testTrustOnFirstUseNotCompatibleWithRoutingDriver()
{
    // Given a configuration that requires encryption and trusts on first use
    File tofuFile = new File( "foo" );
    Config.TrustStrategy tofuStrategy = Config.TrustStrategy.trustOnFirstUse( tofuFile );
    final Config tofuConfig = Config.build()
            .withEncryptionLevel( Config.EncryptionLevel.REQUIRED )
            .withTrustStrategy( tofuStrategy )
            .toConfig();

    try
    {
        // When
        GraphDatabase.driver( "bolt+routing://127.0.0.1:7687", tofuConfig );
        // Reaching this line means the driver was created, which the test treats as a failure
        fail();
    }
    catch ( IllegalArgumentException expected )
    {
        // Then the rejection is exactly what the test is looking for
    }
}
/**
 * Creates a driver on top of a mocked pool whose connections all answer the routing call with the given
 * ttl and server entries.
 *
 * @param ttl value placed in the first field of the routing answer
 * @param serverInfo server entries placed in the second field of the routing answer
 * @return driver built by {@code driverWithPool}
 */
@SafeVarargs
private final Driver driverWithServers( long ttl, Map<String,Object>... serverInfo )
{
    ConnectionPool connectionPool = poolWithServers( ttl, serverInfo );
    return driverWithPool( connectionPool );
}
/**
 * Wires an {@link InternalDriver} around the given pool: a {@link LoadBalancer} seeded with
 * {@code SEED} serves as connection provider, and sessions come from a
 * {@code NetworkSessionWithAddressFactory}.
 *
 * @param pool connection pool handed to the load balancer
 * @return driver using an insecure security plan and this test's logging
 */
private Driver driverWithPool( ConnectionPool pool )
{
    ConnectionProvider loadBalancer =
            new LoadBalancer( SEED, new RoutingSettings( 10, 5_000, null ), pool, clock, logging );
    Config loggingConfig = Config.build().withLogging( logging ).toConfig();
    SessionFactory sessions = new NetworkSessionWithAddressFactory( loadBalancer, loggingConfig );
    return new InternalDriver( insecure(), sessions, logging );
}
/**
 * Creates a mocked pool whose connections always answer the routing call with the same ttl and server
 * entries.
 *
 * @param ttl value placed in the first field of the routing answer
 * @param serverInfo server entries placed in the second field of the routing answer
 * @return mocked connection pool
 */
@SafeVarargs
private final ConnectionPool poolWithServers( long ttl, Map<String,Object>... serverInfo )
{
    Answer routingAnswer = withServers( ttl, serverInfo );
    return pool( routingAnswer );
}
/**
 * Builds an answer that feeds a single two-field record to the collector: the ttl followed by the list
 * of server entries.
 *
 * @param ttl value of the first field
 * @param serverInfo entries wrapped into a list for the second field
 * @return answer produced by {@code withServerList}
 */
@SafeVarargs
private static Answer withServers( long ttl, Map<String,Object>... serverInfo )
{
    Value ttlField = value( ttl );
    Value serversField = value( asList( serverInfo ) );
    Value[] record = {ttlField, serversField};
    return withServerList( record );
}
/**
 * Shorthand for creating a {@link BoltServerAddress}.
 *
 * @param host host part of the address
 * @param port port part of the address
 * @return new address for the given host and port
 */
private BoltServerAddress boltAddress( String host, int port )
{
    BoltServerAddress address = new BoltServerAddress( host, port );
    return address;
}
/**
 * Creates a mocked {@link ConnectionPool} that hands out a fresh mocked connection on every
 * {@code acquire}. The n-th connection handed out answers {@code pullAll} with the n-th of the given
 * answers; once the answers are exhausted the last one is reused.
 *
 * @param toGetServers {@code pullAll} answer for the first acquired connection
 * @param furtherGetServers {@code pullAll} answers for the connections acquired afterwards, in order
 * @return mocked connection pool
 */
private ConnectionPool pool( final Answer toGetServers, final Answer... furtherGetServers )
{
    Answer<PooledConnection> connectionFactory = new Answer<PooledConnection>()
    {
        // Index of the answer the next acquired connection will be stubbed with
        int nextAnswer;

        @Override
        public PooledConnection answer( InvocationOnMock invocationOnMock ) throws Throwable
        {
            BoltServerAddress requested = invocationOnMock.getArgumentAt( 0, BoltServerAddress.class );
            PooledConnection mockedConnection = mock( PooledConnection.class );
            when( mockedConnection.isOpen() ).thenReturn( true );
            when( mockedConnection.boltServerAddress() ).thenReturn( requested );
            when( mockedConnection.server() ).thenReturn( new InternalServerInfo( requested, "Neo4j/3.1.0" ) );

            // Running the routing procedure only reports the result keys...
            doAnswer( withKeys( "ttl", "servers" ) ).when( mockedConnection ).run(
                    eq( GET_SERVERS ),
                    eq( Collections.<String,Value>emptyMap() ),
                    any( Collector.class ) );

            // ...while pullAll delivers the answer chosen for this connection; the index never
            // exceeds furtherGetServers.length, so the last answer is repeated from then on
            int index = Math.min( nextAnswer, furtherGetServers.length );
            nextAnswer = index + 1;
            Answer pullAllAnswer = index == 0 ? toGetServers : furtherGetServers[index - 1];
            doAnswer( pullAllAnswer ).when( mockedConnection ).pullAll( any( Collector.class ) );
            return mockedConnection;
        }
    };

    ConnectionPool mockedPool = mock( ConnectionPool.class );
    when( mockedPool.acquire( any( BoltServerAddress.class ) ) ).thenAnswer( connectionFactory );
    return mockedPool;
}
/**
 * Builds an answer that reports the given keys to the collector of the intercepted call.
 *
 * @param keys keys passed on to {@code Collector.keys}
 * @return answer reporting the keys
 */
private static CollectorAnswer withKeys( final String... keys )
{
    CollectorAnswer keysAnswer = new CollectorAnswer()
    {
        @Override
        void collect( Collector target )
        {
            target.keys( keys );
        }
    };
    return keysAnswer;
}
/**
 * Builds an answer that feeds the given records, in order, to the collector of the intercepted call.
 *
 * @param records records passed one by one to {@code Collector.record}
 * @return answer delivering the records
 */
private static CollectorAnswer withServerList( final Value[]... records )
{
    CollectorAnswer recordsAnswer = new CollectorAnswer()
    {
        @Override
        void collect( Collector target )
        {
            for ( int i = 0; i < records.length; i++ )
            {
                target.record( records[i] );
            }
        }
    };
    return recordsAnswer;
}
/**
 * Session factory that creates {@link NetworkSessionWithAddress} instances, so tests can read the
 * address each session was bound to. The parent factory is configured with {@code FixedRetryLogic( 0 )}.
 */
private static class NetworkSessionWithAddressFactory extends SessionFactoryImpl
{
    NetworkSessionWithAddressFactory( ConnectionProvider connectionProvider, Config config )
    {
        super( connectionProvider, new FixedRetryLogic( 0 ), config );
    }

    /**
     * Creates an address-recording session for the given mode and assigns it the given bookmark.
     */
    @Override
    public Session newInstance( AccessMode mode, String bookmark )
    {
        NetworkSessionWithAddress created = new NetworkSessionWithAddress( connectionProvider, mode, logging );
        created.setBookmark( bookmark );
        return created;
    }
}
/**
 * Session that records a server address at construction time: it acquires a connection for its access
 * mode from the provider, stores that connection's address and closes the connection again.
 */
private static class NetworkSessionWithAddress extends NetworkSession
{
    // Address of the connection acquired while constructing this session
    final BoltServerAddress address;

    NetworkSessionWithAddress( ConnectionProvider connectionProvider, AccessMode mode, Logging logging )
    {
        super( connectionProvider, mode, new FixedRetryLogic( 0 ), logging );
        // The connection is only needed to learn the address; try-with-resources closes it right away
        try ( PooledConnection probe = connectionProvider.acquireConnection( mode ) )
        {
            address = probe.boltServerAddress();
        }
    }
}
/**
 * Mockito answer for the {@code run} and {@code pullAll} calls stubbed in this test: it picks the
 * {@link Collector} argument of the intercepted call, lets the subclass feed it, and then marks it as
 * done.
 */
private abstract static class CollectorAnswer implements Answer
{
    /**
     * Feeds the collector of the intercepted call; invoked before {@code Collector.done()}.
     */
    abstract void collect( Collector collector );

    @Override
    public final Object answer( InvocationOnMock invocation ) throws Throwable
    {
        Collector target = collectorOf( invocation );
        collect( target );
        target.done();
        return null;
    }

    /**
     * Extracts the collector argument: index 0 for {@code pullAll}, index 2 for {@code run}.
     *
     * @throws UnsupportedOperationException for any other intercepted method
     */
    private Collector collectorOf( InvocationOnMock invocation )
    {
        String methodName = invocation.getMethod().getName();
        if ( "pullAll".equals( methodName ) )
        {
            return invocation.getArgumentAt( 0, Collector.class );
        }
        if ( "run".equals( methodName ) )
        {
            return invocation.getArgumentAt( 2, Collector.class );
        }
        throw new UnsupportedOperationException( methodName );
    }
}
}