/** * Copyright (c) 2002-2012 "Neo Technology," * Network Engine for Objects in Lund AB [http://neotechnology.com] * * This file is part of Neo4j. * * Neo4j is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package org.neo4j.cluster; import java.io.PrintWriter; import java.io.StringWriter; import java.net.URI; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import ch.qos.logback.classic.LoggerContext; import ch.qos.logback.classic.joran.JoranConfigurator; import ch.qos.logback.core.joran.spi.JoranException; import org.neo4j.cluster.com.message.Message; import org.neo4j.cluster.com.message.MessageType; import org.neo4j.cluster.protocol.atomicbroadcast.multipaxos.InMemoryAcceptorInstanceStore; import org.neo4j.cluster.protocol.cluster.ClusterConfiguration; import org.neo4j.cluster.protocol.election.ServerIdElectionCredentialsProvider; import org.neo4j.cluster.statemachine.StateTransitionLogger; import org.neo4j.cluster.timeout.MessageTimeoutStrategy; import org.neo4j.kernel.impl.util.StringLogger; import org.neo4j.kernel.logging.LogbackService; import org.neo4j.kernel.logging.Logging; import org.slf4j.LoggerFactory; /** * This mocks message delivery, message loss, and time for timeouts and message latency * between protocol servers. */ public class NetworkMock { Map<String, TestProtocolServer> participants = new LinkedHashMap<String, TestProtocolServer>(); private List<MessageDelivery> messageDeliveries = new ArrayList<MessageDelivery>(); private long now = 0; private long tickDuration; private final MultipleFailureLatencyStrategy strategy; private MessageTimeoutStrategy timeoutStrategy; private Logging logging; protected final StringLogger logger; public NetworkMock( long tickDuration, MultipleFailureLatencyStrategy strategy, MessageTimeoutStrategy timeoutStrategy ) { this.tickDuration = tickDuration; this.strategy = strategy; this.timeoutStrategy = timeoutStrategy; this.logging = new LogbackService( null, (LoggerContext) LoggerFactory.getILoggerFactory() ); logger = logging.getLogger( NetworkMock.class.getName() ); } public TestProtocolServer addServer( URI serverId ) { TestProtocolServer server = newTestProtocolServer( serverId ); debug( serverId.toString(), "joins network" ); participants.put( serverId.toString(), server ); return server; } protected TestProtocolServer newTestProtocolServer( URI serverId ) { LoggerContext loggerContext = new LoggerContext(); loggerContext.putProperty( "host", serverId.toString() ); JoranConfigurator configurator = new JoranConfigurator(); configurator.setContext( loggerContext ); try { configurator.doConfigure( getClass().getResource( "/test-logback.xml" ) ); } catch ( JoranException e ) { throw new IllegalStateException( "Failed to configure logging", e ); } Logging logging = new LogbackService( null, loggerContext ); ProtocolServerFactory protocolServerFactory = new MultiPaxosServerFactory( new ClusterConfiguration( "default" ), logging ); ServerIdElectionCredentialsProvider electionCredentialsProvider = new ServerIdElectionCredentialsProvider(); electionCredentialsProvider.listeningAt( serverId ); TestProtocolServer protocolServer = new TestProtocolServer( timeoutStrategy, protocolServerFactory, serverId, new InMemoryAcceptorInstanceStore(), electionCredentialsProvider ); protocolServer.addStateTransitionListener( new StateTransitionLogger( logging ) ); return protocolServer; } private void debug( String participant, String string ) { logger.info( "=== " + participant + " " + string ); } public void removeServer( String serverId ) { debug( serverId, "leaves network" ); participants.remove( serverId ); } public int tick() { // Deliver messages whose delivery time has passed now += tickDuration; // logger.debug( "tick:"+now ); Iterator<MessageDelivery> iter = messageDeliveries.iterator(); while ( iter.hasNext() ) { MessageDelivery messageDelivery = iter.next(); if ( messageDelivery.getMessageDeliveryTime() <= now ) { long delay = strategy.messageDelay( messageDelivery.getMessage(), messageDelivery.getServer().toString() ); if ( delay != NetworkLatencyStrategy.LOST ) { messageDelivery.getServer().process( messageDelivery.getMessage() ); } iter.remove(); } } // Check and trigger timeouts for ( TestProtocolServer testServer : participants.values() ) { testServer.tick( now ); } // Get all sent messages from all test servers List<Message> messages = new ArrayList<Message>(); for ( TestProtocolServer testServer : participants.values() ) { testServer.sendMessages( messages ); } // Now send them and figure out latency for ( Message message : messages ) { String to = message.getHeader( Message.TO ); if ( to.equals( Message.BROADCAST ) ) { for ( Map.Entry<String, TestProtocolServer> testServer : participants.entrySet() ) { if ( !testServer.getKey().equals( message.getHeader( Message.FROM ) ) ) { long delay = strategy.messageDelay( message, testServer.getKey() ); if ( delay == NetworkLatencyStrategy.LOST ) { logger.info( "Broadcasted message to " + testServer.getKey() + " was lost" ); } else { logger.info( "Broadcast to " + testServer.getKey() + ": " + message ); messageDeliveries.add( new MessageDelivery( now + delay, message, testServer.getValue() ) ); } } } } else { long delay = 0; if ( message.getHeader( Message.TO ).equals( message.getHeader( Message.FROM ) ) ) { logger.info( "Sending message to itself; zero latency" ); } else { delay = strategy.messageDelay( message, to ); } if ( delay == NetworkLatencyStrategy.LOST ) { logger.info( "Send message to " + to + " was lost" ); } else { TestProtocolServer server = participants.get( to ); logger.info( "Send to " + to + ": " + message ); messageDeliveries.add( new MessageDelivery( now + delay, message, server ) ); } } } return messageDeliveries.size(); } public void tick( int iterations ) { for ( int i = 0; i < iterations; i++ ) { tick(); } } public void tickUntilDone() { while ( tick() + totalCurrentTimeouts() > 0 ) { } } private int totalCurrentTimeouts() { int count = 0; for ( TestProtocolServer testProtocolServer : participants.values() ) { count += testProtocolServer.getTimeouts().getTimeouts().size(); } return count; } @Override public String toString() { StringWriter stringWriter = new StringWriter(); PrintWriter out = new PrintWriter( stringWriter, true ); out.printf( "Now:%s \n", now ); out.printf( "Pending messages:%s \n", messageDeliveries.size() ); out.printf( "Pending timeouts:%s \n", totalCurrentTimeouts() ); for ( TestProtocolServer testProtocolServer : participants.values() ) { out.println( " " + testProtocolServer ); } return stringWriter.toString(); } public Long getTime() { return now; } public List<TestProtocolServer> getServers() { return new ArrayList<TestProtocolServer>( participants.values() ); } public MultipleFailureLatencyStrategy getNetworkLatencyStrategy() { return strategy; } public MessageTimeoutStrategy getTimeoutStrategy() { return timeoutStrategy; } private static class MessageDelivery { long messageDeliveryTime; Message<? extends MessageType> message; TestProtocolServer server; private MessageDelivery( long messageDeliveryTime, Message<? extends MessageType> message, TestProtocolServer server ) { this.messageDeliveryTime = messageDeliveryTime; this.message = message; this.server = server; } public long getMessageDeliveryTime() { return messageDeliveryTime; } public Message<? extends MessageType> getMessage() { return message; } public TestProtocolServer getServer() { return server; } @Override public String toString() { return "Deliver " + message.getMessageType().name() + " to " + server.getServer().getServerId() + " at " + messageDeliveryTime; } } }