/** * Copyright (c) 2002-2012 "Neo Technology," * Network Engine for Objects in Lund AB [http://neotechnology.com] * * This file is part of Neo4j. * * Neo4j is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package org.neo4j.cluster.protocol.cluster; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Random; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeoutException; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.neo4j.cluster.ConnectedStateMachines; import org.neo4j.cluster.FixedNetworkLatencyStrategy; import org.neo4j.cluster.MultipleFailureLatencyStrategy; import org.neo4j.cluster.NetworkMock; import org.neo4j.cluster.ScriptableNetworkFailureLatencyStrategy; import org.neo4j.cluster.TestProtocolServer; import org.neo4j.cluster.protocol.atomicbroadcast.AtomicBroadcast; import org.neo4j.cluster.protocol.atomicbroadcast.AtomicBroadcastListener; import org.neo4j.cluster.protocol.atomicbroadcast.AtomicBroadcastSerializer; import org.neo4j.cluster.protocol.atomicbroadcast.Payload; import org.neo4j.cluster.protocol.heartbeat.Heartbeat; import org.neo4j.cluster.protocol.heartbeat.HeartbeatContext; import org.neo4j.cluster.protocol.heartbeat.HeartbeatListener; import org.neo4j.cluster.protocol.heartbeat.HeartbeatMessage; import org.neo4j.cluster.statemachine.State; import org.neo4j.cluster.timeout.FixedTimeoutStrategy; import org.neo4j.cluster.timeout.MessageTimeoutStrategy; import org.neo4j.helpers.NamedThreadFactory; import org.neo4j.helpers.collection.Iterables; import org.neo4j.test.LoggerRule; /** * Base class for cluster tests */ public class ClusterMockTest { public static NetworkMock DEFAULT_NETWORK() { return new NetworkMock( 10, new MultipleFailureLatencyStrategy( new FixedNetworkLatencyStrategy( 10 ), new ScriptableNetworkFailureLatencyStrategy() ), new MessageTimeoutStrategy( new FixedTimeoutStrategy( 500 ) ) .timeout( HeartbeatMessage.sendHeartbeat, 200 ) ); } List<TestProtocolServer> servers = new ArrayList<TestProtocolServer>(); List<Cluster> out = new ArrayList<Cluster>(); List<Cluster> in = new ArrayList<Cluster>(); @Rule public LoggerRule logger = new LoggerRule(); public NetworkMock network; ClusterTestScript script; ExecutorService executor; @Before public void setup() { executor = Executors.newSingleThreadExecutor( new NamedThreadFactory( "Configuration output" ) ); } @After public void tearDown() { executor.shutdownNow(); } protected void testCluster( int nrOfServers, NetworkMock mock, ClusterTestScript script ) throws ExecutionException, InterruptedException, URISyntaxException, TimeoutException { this.script = script; network = mock; servers.clear(); out.clear(); in.clear(); for ( int i = 0; i < nrOfServers; i++ ) { final URI uri = new URI( "server" + (i + 1) ); TestProtocolServer server = network.addServer( uri ); final Cluster cluster = server.newClient( Cluster.class ); clusterStateListener( uri, cluster ); server.newClient( Heartbeat.class ).addHeartbeatListener( new HeartbeatListener() { @Override public void failed( URI server ) { logger.getLogger().warn( uri + ": Failed:" + server ); } @Override public void alive( URI server ) { logger.getLogger().info( uri + ": Alive:" + server ); } } ); server.newClient( AtomicBroadcast.class ).addAtomicBroadcastListener( new AtomicBroadcastListener() { AtomicBroadcastSerializer serializer = new AtomicBroadcastSerializer(); @Override public void receive( Payload value ) { try { logger.getLogger().info( uri + " received: " + serializer.receive( value ) ); } catch ( IOException e ) { e.printStackTrace(); } catch ( ClassNotFoundException e ) { e.printStackTrace(); } } } ); servers.add( server ); out.add( cluster ); } // Run test for ( int i = 0; i < script.rounds(); i++ ) { logger.getLogger().info( "Round " + i + ", time:" + network.getTime() ); script.tick( network.getTime() ); network.tick(); } // Let messages settle network.tick( 100 ); verifyConfigurations(); logger.getLogger().info( "All nodes leave" ); // All leave for ( Cluster cluster : new ArrayList<Cluster>( in ) ) { logger.getLogger().info( "Leaving:" + cluster ); cluster.leave(); in.remove( cluster ); network.tick( 400 ); } verifyConfigurations(); } private void clusterStateListener( final URI uri, final Cluster cluster ) { cluster.addClusterListener( new ClusterListener() { @Override public void enteredCluster( ClusterConfiguration clusterConfiguration ) { logger.getLogger().info( uri + " entered cluster:" + clusterConfiguration.getMembers() ); in.add( cluster ); } @Override public void joinedCluster( URI member ) { logger.getLogger().info( uri + " sees a join:" + member.toString() ); } @Override public void leftCluster( URI member ) { logger.getLogger().info( uri + " sees a leave:" + member.toString() ); } @Override public void leftCluster() { logger.getLogger().info( uri + " left cluster" ); out.add( cluster ); } @Override public void elected( String role, URI electedMember ) { logger.getLogger().info( uri + " sees an election: " + electedMember + " elected as " + role ); } } ); } public void verifyConfigurations() { logger.getLogger().info( "Verify configurations" ); List<URI> members = null; Map<String, URI> roles = null; List<URI> failed = null; int foundConfiguration = 0; List<TestProtocolServer> protocolServers = network.getServers(); List<AssertionError> errors = new ArrayList<AssertionError>(); for ( int j = 0; j < protocolServers.size(); j++ ) { ConnectedStateMachines connectedStateMachines = protocolServers.get( j ) .getServer() .getConnectedStateMachines(); State<?, ?> clusterState = connectedStateMachines.getStateMachine( ClusterMessage.class ).getState(); if ( !clusterState.equals( ClusterState.entered ) ) { logger.getLogger().warn( "Instance " + (j + 1) + " is not in the cluster (" + clusterState + ")" ); continue; } ClusterContext context = (ClusterContext) connectedStateMachines.getStateMachine( ClusterMessage.class ) .getContext(); HeartbeatContext heartbeatContext = (HeartbeatContext) connectedStateMachines.getStateMachine( HeartbeatMessage.class ).getContext(); ClusterConfiguration clusterConfiguration = context.getConfiguration(); if ( !clusterConfiguration.getMembers().isEmpty() ) { logger.getLogger().info( " Server " + (j + 1) + ": Cluster:" + clusterConfiguration.getMembers() + "," + " Roles:" + clusterConfiguration.getRoles() + ", Failed:" + heartbeatContext.getFailed() ); foundConfiguration++; if ( members == null ) { members = clusterConfiguration.getMembers(); roles = clusterConfiguration.getRoles(); failed = heartbeatContext.getFailed(); } else { try { assertEquals( "Config for server" + (j + 1) + " is wrong", new HashSet<URI>( members ), new HashSet<URI>( clusterConfiguration .getMembers() ) ); } catch ( AssertionError e ) { errors.add( e ); } try { assertEquals( "Roles for server" + (j + 1) + " is wrong", roles, clusterConfiguration .getRoles() ); } catch ( AssertionError e ) { errors.add( e ); } try { assertEquals( "Failed for server" + (j + 1) + " is wrong", failed, heartbeatContext.getFailed () ); } catch ( AssertionError e ) { errors.add( e ); } } } } if ( !errors.isEmpty() ) { for ( AssertionError error : errors ) { logger.getLogger().error( error.toString() ); } throw errors.get( 0 ); } if ( foundConfiguration > 0 ) { assertEquals( "Nr of found active members does not match configuration size", members.size(), foundConfiguration ); } assertEquals( "In:" + in + ", Out:" + out, protocolServers.size(), Iterables.count( Iterables.<Cluster, List<Cluster>>flatten( in, out ) ) ); } public interface ClusterTestScript { int rounds(); void tick( long time ); } public class ClusterTestScriptDSL implements ClusterTestScript { public abstract class ClusterAction implements Runnable { public long time; } private Queue<ClusterAction> actions = new LinkedList<ClusterAction>(); private AtomicBroadcastSerializer serializer = new AtomicBroadcastSerializer(); private int rounds = 100; private long now = 0; public ClusterTestScriptDSL rounds( int n ) { rounds = n; return this; } public ClusterTestScriptDSL join( int time, final int joinServer ) { return addAction( new ClusterAction() { @Override public void run() { Cluster joinCluster = servers.get( joinServer - 1 ).newClient( Cluster.class ); for ( final Cluster cluster : out ) { if ( cluster.equals( joinCluster ) ) { out.remove( cluster ); logger.getLogger().info( "Join:" + cluster.toString() ); if ( in.isEmpty() ) { cluster.create( "default" ); } else { try { final Future<ClusterConfiguration> result = cluster.join( new URI( in.get( 0 ) .toString() ) ); executor.submit( new Runnable() { @Override public void run() { try { ClusterConfiguration clusterConfiguration = result.get(); logger.getLogger().info( "**** Cluster configuration:" + clusterConfiguration ); } catch ( Exception e ) { logger.getLogger().info( "**** Node could not join cluster:" + e .getMessage() ); out.add( cluster ); } } } ); } catch ( URISyntaxException e ) { e.printStackTrace(); } } break; } } } }, time ); } public ClusterTestScriptDSL leave( long time, final int leaveServer ) { return addAction( new ClusterAction() { @Override public void run() { Cluster leaveCluster = servers.get( leaveServer - 1 ).newClient( Cluster.class ); for ( Cluster cluster : in ) { if ( cluster.equals( leaveCluster ) ) { in.remove( cluster ); cluster.leave(); logger.getLogger().info( "Leave:" + cluster.toString() ); break; } } } }, time ); } public ClusterTestScriptDSL down( int time, final int serverDown ) { return addAction( new ClusterAction() { @Override public void run() { Cluster server = servers.get( serverDown - 1 ).newClient( Cluster.class ); network.getNetworkLatencyStrategy().getStrategy( ScriptableNetworkFailureLatencyStrategy.class ) .nodeIsDown( server.toString() ); logger.getLogger().info( server + " is down" ); } }, time ); } public ClusterTestScriptDSL up( int time, final int serverUp ) { return addAction( new ClusterAction() { @Override public void run() { Cluster server = servers.get( serverUp - 1 ).newClient( Cluster.class ); network.getNetworkLatencyStrategy().getStrategy( ScriptableNetworkFailureLatencyStrategy.class ) .nodeIsUp( server .toString() ); logger.getLogger().info( server + " is up" ); } }, time ); } public ClusterTestScriptDSL broadcast( int time, final int server, final Object value ) { return addAction( new ClusterAction() { @Override public void run() { AtomicBroadcast broadcast = servers.get( server - 1 ).newClient( AtomicBroadcast.class ); try { broadcast.broadcast( serializer.broadcast( value ) ); } catch ( IOException e ) { e.printStackTrace(); } } }, time ); } public ClusterTestScriptDSL sleep( final int sleepTime ) { return addAction( new ClusterAction() { @Override public void run() { logger.getLogger().info( "Slept for " + sleepTime ); } }, sleepTime ); } public ClusterTestScriptDSL message( int time, final String msg ) { return addAction( new ClusterAction() { @Override public void run() { logger.getLogger().info( msg ); } }, time ); } public ClusterTestScriptDSL verifyConfigurations( long time ) { return addAction( new ClusterAction() { @Override public void run() { ClusterMockTest.this.verifyConfigurations(); } }, time ); } private ClusterTestScriptDSL addAction( ClusterAction action, long time ) { action.time = now + time; actions.offer( action ); now += time; return this; } @Override public int rounds() { return rounds; } @Override public void tick( long time ) { while ( !actions.isEmpty() && actions.peek().time == time ) { actions.poll().run(); } } public ClusterTestScriptDSL getRoles( int time, final Map<String, URI> roles ) { return addAction( new ClusterAction() { @Override public void run() { ClusterMockTest.this.getRoles( roles ); } }, 0 ); } public ClusterTestScriptDSL verifyCoordinatorRoleSwitched( final Map<String, URI> comparedTo ) { return addAction( new ClusterAction() { @Override public void run() { HashMap<String, URI> roles = new HashMap<String, URI>(); ClusterMockTest.this.getRoles( roles ); URI oldCoordinator = comparedTo.get( ClusterConfiguration.COORDINATOR ); URI newCoordinator = roles.get( ClusterConfiguration.COORDINATOR ); assertNotNull( "Should have had a coordinator before bringing it down", oldCoordinator ); assertNotNull( "Should have a new coordinator after the previous failed", newCoordinator ); assertTrue( "Should have elected a new coordinator", !oldCoordinator.equals( newCoordinator ) ); } }, 0 ); } } public class ClusterTestScriptRandom implements ClusterTestScript { private final long seed; private final Random random; public ClusterTestScriptRandom( long seed ) { if ( seed == -1 ) { seed = System.nanoTime(); } this.seed = seed; random = new Random( seed ); } @Override public int rounds() { return 300; } @Override public void tick( long time ) { if ( time >= (rounds() - 100) * 10 ) { return; } if ( time == 0 ) { logger.getLogger().info( "Random seed:" + seed + "L" ); } if ( random.nextDouble() >= 0.8 ) { double inOrOut = (in.size() - out.size()) / ((double) servers.size()); double whatToDo = random.nextDouble() + inOrOut; logger.getLogger().info( "What to do:" + whatToDo ); if ( whatToDo < 0.5 && !out.isEmpty() ) { int idx = random.nextInt( out.size() ); final Cluster cluster = out.remove( idx ); if ( in.isEmpty() ) { cluster.create( "default" ); } else { try { final Future<ClusterConfiguration> result = cluster.join( new URI( in.get( 0 ) .toString() ) ); executor.submit( new Runnable() { @Override public void run() { try { ClusterConfiguration clusterConfiguration = result.get(); logger.getLogger().info( "**** Cluster configuration:" + clusterConfiguration ); } catch ( Exception e ) { logger.getLogger().info( "**** Node could not join cluster:" + e .getMessage() ); out.add( cluster ); } } } ); } catch ( URISyntaxException e ) { e.printStackTrace(); } } logger.getLogger().info( "Enter cluster:" + cluster.toString() ); } else if ( !in.isEmpty() ) { int idx = random.nextInt( in.size() ); Cluster cluster = in.remove( idx ); cluster.leave(); logger.getLogger().info( "Leave cluster:" + cluster.toString() ); } } } } private void getRoles( Map<String, URI> roles ) { List<TestProtocolServer> protocolServers = network.getServers(); for ( int j = 0; j < protocolServers.size(); j++ ) { ConnectedStateMachines connectedStateMachines = protocolServers.get( j ) .getServer() .getConnectedStateMachines(); State<?, ?> clusterState = connectedStateMachines.getStateMachine( ClusterMessage.class ).getState(); if ( !clusterState.equals( ClusterState.entered ) ) { logger.getLogger().warn( "Instance " + (j + 1) + " is not in the cluster (" + clusterState + ")" ); continue; } ClusterContext context = (ClusterContext) connectedStateMachines.getStateMachine( ClusterMessage.class ) .getContext(); ClusterConfiguration clusterConfiguration = context.getConfiguration(); roles.putAll( clusterConfiguration.getRoles() ); } } }