/* * Copyright (c) 2002-2017 "Neo Technology," * Network Engine for Objects in Lund AB [http://neotechnology.com] * * This file is part of Neo4j. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.driver.internal.cluster; import org.junit.Test; import org.mockito.InOrder; import java.util.Collections; import java.util.HashSet; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import org.neo4j.driver.internal.ExplicitTransaction; import org.neo4j.driver.internal.NetworkSession; import org.neo4j.driver.internal.SessionResourcesHandler; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.retry.ExponentialBackoffRetryLogic; import org.neo4j.driver.internal.retry.RetryLogic; import org.neo4j.driver.internal.retry.RetrySettings; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.spi.PooledConnection; import org.neo4j.driver.internal.util.SleeplessClock; import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.Session; import org.neo4j.driver.v1.Transaction; import org.neo4j.driver.v1.Value; import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.core.IsEqual.equalTo; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.RETURNS_MOCKS; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; import static org.neo4j.driver.v1.AccessMode.READ; import static org.neo4j.driver.v1.AccessMode.WRITE; public class LoadBalancerTest { @Test public void ensureRoutingShouldUpdateRoutingTableAndPurgeConnectionPoolWhenStale() throws Exception { // given ConnectionPool conns = mock( ConnectionPool.class ); RoutingTable routingTable = mock( RoutingTable.class ); Rediscovery rediscovery = mock( Rediscovery.class ); Set<BoltServerAddress> set = singleton( new BoltServerAddress( "abc", 12 ) ); when( routingTable.update( any( ClusterComposition.class ) ) ).thenReturn( set ); // when LoadBalancer balancer = new LoadBalancer( conns, routingTable, rediscovery, DEV_NULL_LOGGER ); // then assertNotNull( balancer ); InOrder inOrder = inOrder( rediscovery, routingTable, conns ); inOrder.verify( rediscovery ).lookupClusterComposition( routingTable, conns ); inOrder.verify( routingTable ).update( any( ClusterComposition.class ) ); inOrder.verify( conns ).purge( new BoltServerAddress( "abc", 12 ) ); } @Test public void shouldRefreshRoutingTableOnInitialization() throws Exception { // given & when final AtomicInteger refreshRoutingTableCounter = new AtomicInteger( 0 ); LoadBalancer balancer = new LoadBalancer( mock( ConnectionPool.class ), mock( RoutingTable.class ), mock( Rediscovery.class ), DEV_NULL_LOGGER ) { @Override synchronized void refreshRoutingTable() { refreshRoutingTableCounter.incrementAndGet(); } }; // then assertNotNull( balancer ); assertThat( refreshRoutingTableCounter.get(), equalTo( 1 ) ); } @Test public void shouldEnsureRoutingWhenAcquireConn() throws Exception { // given PooledConnection writerConn = mock( PooledConnection.class ); PooledConnection readConn = mock( PooledConnection.class ); LoadBalancer balancer = setupLoadBalancer( writerConn, readConn ); LoadBalancer spy = spy( balancer ); // when Connection connection = spy.acquireConnection( READ ); connection.init( "Test", Collections.<String,Value>emptyMap() ); // then verify( spy ).ensureRouting( READ ); verify( readConn ).init( "Test", Collections.<String,Value>emptyMap() ); } @Test public void shouldAcquireReaderOrWriterConn() throws Exception { PooledConnection writerConn = mock( PooledConnection.class ); PooledConnection readConn = mock( PooledConnection.class ); LoadBalancer balancer = setupLoadBalancer( writerConn, readConn ); Connection acquiredReadConn = balancer.acquireConnection( READ ); acquiredReadConn.init( "TestRead", Collections.<String,Value>emptyMap() ); verify( readConn ).init( "TestRead", Collections.<String,Value>emptyMap() ); Connection acquiredWriteConn = balancer.acquireConnection( WRITE ); acquiredWriteConn.init( "TestWrite", Collections.<String,Value>emptyMap() ); verify( writerConn ).init( "TestWrite", Collections.<String,Value>emptyMap() ); } @Test public void shouldForgetAddressAndItsConnectionsOnServiceUnavailableWhileClosingTx() { RoutingTable routingTable = mock( RoutingTable.class ); ConnectionPool connectionPool = mock( ConnectionPool.class ); Rediscovery rediscovery = mock( Rediscovery.class ); LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGER ); BoltServerAddress address = new BoltServerAddress( "host", 42 ); PooledConnection connection = newConnectionWithFailingSync( address ); Connection routingConnection = new RoutingPooledConnection( connection, loadBalancer, AccessMode.WRITE ); Transaction tx = new ExplicitTransaction( routingConnection, mock( SessionResourcesHandler.class ) ); try { tx.close(); fail( "Exception expected" ); } catch ( Exception e ) { assertThat( e, instanceOf( SessionExpiredException.class ) ); assertThat( e.getCause(), instanceOf( ServiceUnavailableException.class ) ); } verify( routingTable ).forget( address ); verify( connectionPool ).purge( address ); } @Test public void shouldForgetAddressAndItsConnectionsOnServiceUnavailableWhileClosingSession() { RoutingTable routingTable = mock( RoutingTable.class, RETURNS_MOCKS ); ConnectionPool connectionPool = mock( ConnectionPool.class ); BoltServerAddress address = new BoltServerAddress( "host", 42 ); PooledConnection connectionWithFailingSync = newConnectionWithFailingSync( address ); when( connectionPool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connectionWithFailingSync ); Rediscovery rediscovery = mock( Rediscovery.class ); LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGER ); Session session = newSession( loadBalancer ); // begin transaction to make session obtain a connection session.beginTransaction(); session.close(); verify( routingTable ).forget( address ); verify( connectionPool ).purge( address ); } @Test public void shouldRediscoverOnReadWhenRoutingTableIsStaleForReads() { testRediscoveryWhenStale( READ ); } @Test public void shouldRediscoverOnWriteWhenRoutingTableIsStaleForWrites() { testRediscoveryWhenStale( WRITE ); } @Test public void shouldNotRediscoverOnReadWhenRoutingTableIsStaleForWritesButNotReads() { testNoRediscoveryWhenNotStale( WRITE, READ ); } @Test public void shouldNotRediscoverOnWriteWhenRoutingTableIsStaleForReadsButNotWrites() { testNoRediscoveryWhenNotStale( READ, WRITE ); } @Test public void shouldThrowWhenRediscoveryReturnsNoSuitableServers() { ConnectionPool connections = mock( ConnectionPool.class ); RoutingTable routingTable = mock( RoutingTable.class ); when( routingTable.isStaleFor( any( AccessMode.class ) ) ).thenReturn( true ); Rediscovery rediscovery = mock( Rediscovery.class ); when( routingTable.readers() ).thenReturn( new RoundRobinAddressSet() ); when( routingTable.writers() ).thenReturn( new RoundRobinAddressSet() ); LoadBalancer loadBalancer = new LoadBalancer( connections, routingTable, rediscovery, DEV_NULL_LOGGER ); try { loadBalancer.acquireConnection( READ ); fail( "Exception expected" ); } catch ( Exception e ) { assertThat( e, instanceOf( SessionExpiredException.class ) ); assertThat( e.getMessage(), startsWith( "Failed to obtain connection towards READ server" ) ); } try { loadBalancer.acquireConnection( WRITE ); fail( "Exception expected" ); } catch ( Exception e ) { assertThat( e, instanceOf( SessionExpiredException.class ) ); assertThat( e.getMessage(), startsWith( "Failed to obtain connection towards WRITE server" ) ); } } private void testRediscoveryWhenStale( AccessMode mode ) { ConnectionPool connections = mock( ConnectionPool.class ); when( connections.acquire( LOCAL_DEFAULT ) ).thenReturn( mock( PooledConnection.class ) ); RoutingTable routingTable = newStaleRoutingTableMock( mode ); Rediscovery rediscovery = newRediscoveryMock(); LoadBalancer loadBalancer = new LoadBalancer( connections, routingTable, rediscovery, DEV_NULL_LOGGER ); verify( rediscovery ).lookupClusterComposition( routingTable, connections ); assertNotNull( loadBalancer.acquireConnection( mode ) ); verify( routingTable ).isStaleFor( mode ); verify( rediscovery, times( 2 ) ).lookupClusterComposition( routingTable, connections ); } private void testNoRediscoveryWhenNotStale( AccessMode staleMode, AccessMode notStaleMode ) { ConnectionPool connections = mock( ConnectionPool.class ); when( connections.acquire( LOCAL_DEFAULT ) ).thenReturn( mock( PooledConnection.class ) ); RoutingTable routingTable = newStaleRoutingTableMock( staleMode ); Rediscovery rediscovery = newRediscoveryMock(); LoadBalancer loadBalancer = new LoadBalancer( connections, routingTable, rediscovery, DEV_NULL_LOGGER ); verify( rediscovery ).lookupClusterComposition( routingTable, connections ); assertNotNull( loadBalancer.acquireConnection( notStaleMode ) ); verify( routingTable ).isStaleFor( notStaleMode ); verify( rediscovery ).lookupClusterComposition( routingTable, connections ); } private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConnection readConn ) { return setupLoadBalancer( writerConn, readConn, mock( Rediscovery.class ) ); } private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConnection readConn, Rediscovery rediscovery ) { BoltServerAddress writer = mock( BoltServerAddress.class ); BoltServerAddress reader = mock( BoltServerAddress.class ); ConnectionPool connPool = mock( ConnectionPool.class ); when( connPool.acquire( writer ) ).thenReturn( writerConn ); when( connPool.acquire( reader ) ).thenReturn( readConn ); RoundRobinAddressSet writerAddrs = mock( RoundRobinAddressSet.class ); when( writerAddrs.next() ).thenReturn( writer ); RoundRobinAddressSet readerAddrs = mock( RoundRobinAddressSet.class ); when( readerAddrs.next() ).thenReturn( reader ); RoutingTable routingTable = mock( RoutingTable.class ); when( routingTable.readers() ).thenReturn( readerAddrs ); when( routingTable.writers() ).thenReturn( writerAddrs ); return new LoadBalancer( connPool, routingTable, rediscovery, DEV_NULL_LOGGER ); } private static Session newSession( LoadBalancer loadBalancer ) { SleeplessClock clock = new SleeplessClock(); RetryLogic retryLogic = new ExponentialBackoffRetryLogic( RetrySettings.DEFAULT, clock, DEV_NULL_LOGGING ); return new NetworkSession( loadBalancer, AccessMode.WRITE, retryLogic, DEV_NULL_LOGGING ); } private static PooledConnection newConnectionWithFailingSync( BoltServerAddress address ) { PooledConnection connection = mock( PooledConnection.class ); doReturn( true ).when( connection ).isOpen(); doReturn( address ).when( connection ).boltServerAddress(); ServiceUnavailableException closeError = new ServiceUnavailableException( "Oh!" ); doThrow( closeError ).when( connection ).sync(); return connection; } private static RoutingTable newStaleRoutingTableMock( AccessMode mode ) { RoutingTable routingTable = mock( RoutingTable.class ); when( routingTable.isStaleFor( mode ) ).thenReturn( true ); when( routingTable.update( any( ClusterComposition.class ) ) ).thenReturn( new HashSet<BoltServerAddress>() ); RoundRobinAddressSet addresses = new RoundRobinAddressSet(); addresses.update( new HashSet<>( singletonList( LOCAL_DEFAULT ) ), new HashSet<BoltServerAddress>() ); when( routingTable.readers() ).thenReturn( addresses ); when( routingTable.writers() ).thenReturn( addresses ); return routingTable; } private static Rediscovery newRediscoveryMock() { Rediscovery rediscovery = mock( Rediscovery.class ); Set<BoltServerAddress> noServers = Collections.<BoltServerAddress>emptySet(); ClusterComposition clusterComposition = new ClusterComposition( 1, noServers, noServers, noServers ); when( rediscovery.lookupClusterComposition( any( RoutingTable.class ), any( ConnectionPool.class ) ) ) .thenReturn( clusterComposition ); return rediscovery; } }