/**
* Copyright (c) 2002-2011 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.kernel.ha;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.neo4j.helpers.Pair;
import org.neo4j.kernel.impl.util.StringLogger;
/**
* Sits on the master side, receiving serialized requests from slaves (via
* {@link MasterClient}). Delegates actual work to {@link MasterImpl}.
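 * <p>
 * A minimal construction sketch (illustrative only; in a real HA setup the
 * kernel performs this wiring, and the {@code MasterImpl} constructor shown
 * here is an assumption):
 * <pre>
 * Master master = new MasterImpl( graphDb ); // assumed wiring over a local db
 * MasterServer server = new MasterServer( master, 6361, "/path/to/store-dir" );
 * // ... slaves connect and have their requests served ...
 * server.shutdown();
 * </pre>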
*/
public class MasterServer extends CommunicationProtocol implements ChannelPipelineFactory
{
    private static final int DEAD_CONNECTIONS_CHECK_INTERVAL = 3;
    private static final int MAX_NUMBER_OF_CONCURRENT_TRANSACTIONS = 200;
private final ChannelFactory channelFactory;
private final ServerBootstrap bootstrap;
private final Master realMaster;
private final ChannelGroup channelGroup;
private final ScheduledExecutorService deadConnectionsPoller;
private final Map<Channel, SlaveContext> connectedSlaveChannels =
new HashMap<Channel, SlaveContext>();
private final Map<Channel, Pair<ChannelBuffer, ByteBuffer>> channelBuffers =
new HashMap<Channel, Pair<ChannelBuffer,ByteBuffer>>();
private final ExecutorService executor;
private final StringLogger msgLog;
private final Map<Channel, PartialRequest> partialRequests =
Collections.synchronizedMap( new HashMap<Channel, PartialRequest>() );
public MasterServer( Master realMaster, final int port, String storeDir )
{
this.realMaster = realMaster;
this.msgLog = StringLogger.getLogger( storeDir + "/messages.log" );
executor = Executors.newCachedThreadPool();
channelFactory = new NioServerSocketChannelFactory(
executor, executor, MAX_NUMBER_OF_CONCURRENT_TRANSACTIONS );
bootstrap = new ServerBootstrap( channelFactory );
bootstrap.setPipelineFactory( this );
Channel channel;
try
{
channel = bootstrap.bind( new InetSocketAddress( port ) );
}
catch ( ChannelException e )
{
msgLog.logMessage( "Failed to bind master server to port " + port, e );
executor.shutdown();
throw e;
}
channelGroup = new DefaultChannelGroup();
// Add the "server" channel
channelGroup.add( channel );
msgLog.logMessage( "Master server bound to " + port, true );
deadConnectionsPoller = new ScheduledThreadPoolExecutor( 1 );
deadConnectionsPoller.scheduleWithFixedDelay( new Runnable()
{
public void run()
{
checkForDeadChannels();
}
}, DEAD_CONNECTIONS_CHECK_INTERVAL, DEAD_CONNECTIONS_CHECK_INTERVAL, TimeUnit.SECONDS );
}
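    /*
     * Called by Netty for each new connection: every channel gets its own
     * pipeline, with the length-field framing handlers inherited from
     * CommunicationProtocol sitting in front of the request-handling
     * ServerHandler.
     */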
public ChannelPipeline getPipeline() throws Exception
{
ChannelPipeline pipeline = Channels.pipeline();
addLengthFieldPipes( pipeline );
pipeline.addLast( "serverHandler", new ServerHandler() );
return pipeline;
}
private class ServerHandler extends SimpleChannelHandler
{
@Override
public void messageReceived( ChannelHandlerContext ctx, MessageEvent event )
throws Exception
{
try
{
ChannelBuffer message = (ChannelBuffer) event.getMessage();
handleRequest( realMaster, message, event.getChannel() );
}
            catch ( Exception e )
            {
                msgLog.logMessage( "Error handling slave request", e );
                throw e;
            }
}
@Override
public void exceptionCaught( ChannelHandlerContext ctx, ExceptionEvent e ) throws Exception
{
            msgLog.logMessage( "Exception caught in server pipeline", e.getCause() );
}
}
@SuppressWarnings( "unchecked" )
private void handleRequest( Master realMaster, ChannelBuffer buffer,
final Channel channel ) throws IOException
{
// TODO Too long method, refactor please
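        /*
         * Wire format, as consumed here: the first byte of every frame is a
         * continuation flag. CONTINUATION_MORE means more chunks of this
         * request will follow, so the payload is accumulated in a
         * PartialRequest keyed by channel; any other value marks the final
         * (or only) chunk. The first chunk of a request also carries a
         * request type byte and, when the type calls for it, a serialized
         * SlaveContext.
         */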
byte continuation = buffer.readByte();
if ( continuation == ChunkingChannelBuffer.CONTINUATION_MORE )
{
PartialRequest partialRequest = partialRequests.get( channel );
if ( partialRequest == null )
{
// This is the first chunk
RequestType type = RequestType.values()[buffer.readByte()];
SlaveContext context = null;
if ( type.includesSlaveContext() )
{
context = readSlaveContext( buffer );
}
Pair<ChannelBuffer, ByteBuffer> targetBuffers = mapSlave( channel, context );
partialRequest = new PartialRequest( type, context, targetBuffers );
partialRequests.put( channel, partialRequest );
}
partialRequest.add( buffer );
}
else
{
PartialRequest partialRequest = partialRequests.remove( channel );
RequestType type = null;
SlaveContext context = null;
Pair<ChannelBuffer, ByteBuffer> targetBuffers;
ChannelBuffer bufferToReadFrom = null;
ChannelBuffer bufferToWriteTo = null;
if ( partialRequest == null )
{
type = RequestType.values()[buffer.readByte()];
if ( type.includesSlaveContext() )
{
context = readSlaveContext( buffer );
}
targetBuffers = mapSlave( channel, context );
bufferToReadFrom = buffer;
bufferToWriteTo = targetBuffers.first();
}
else
{
type = partialRequest.type;
context = partialRequest.slaveContext;
targetBuffers = partialRequest.buffers;
partialRequest.add( buffer );
bufferToReadFrom = targetBuffers.first();
bufferToWriteTo = ChannelBuffers.dynamicBuffer();
}
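            // The complete request now sits in bufferToReadFrom; the response
            // is streamed back to the slave in frame-sized chunks through the
            // ChunkingChannelBuffer created below.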
bufferToWriteTo.clear();
final ChunkingChannelBuffer chunkingBuffer = new ChunkingChannelBuffer( bufferToWriteTo, channel, MAX_FRAME_LENGTH );
final Response<?> response = type.caller.callMaster( realMaster, context,
bufferToReadFrom, chunkingBuffer );
final ByteBuffer targetByteBuffer = targetBuffers.other();
final RequestType finalType = type;
final SlaveContext finalContext = context;
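            // Serialize the response on a worker thread so the Netty I/O
            // thread isn't tied up while (potentially large) transaction
            // streams are written back to the slave.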
executor.submit( new Runnable()
{
public void run()
{
try
{
finalType.serializer.write( response.response(), chunkingBuffer );
if ( finalType.includesSlaveContext() )
{
writeTransactionStreams( response.transactions(), chunkingBuffer, targetByteBuffer );
}
chunkingBuffer.done();
if ( finalType == RequestType.FINISH || finalType == RequestType.PULL_UPDATES )
{
unmapSlave( channel, finalContext );
}
}
                    catch ( IOException e )
                    {
                        msgLog.logMessage( "Failed to write response to slave", e );
                        throw new RuntimeException( e );
                    }
                    catch ( RuntimeException e )
                    {
                        msgLog.logMessage( "Error while writing response to slave", e );
                        throw e;
                    }
}
} );
}
}
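    /*
     * Registers the channel (and, when present, the slave context tying it to
     * an ongoing transaction) and returns the per-channel target buffers,
     * lazily allocated on first use: a dynamic ChannelBuffer that accumulates
     * chunked request data and a 1 MB direct ByteBuffer used when writing
     * transaction streams back.
     */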
protected Pair<ChannelBuffer, ByteBuffer> mapSlave( Channel channel, SlaveContext slave )
{
channelGroup.add( channel );
Pair<ChannelBuffer, ByteBuffer> buffer = null;
synchronized ( connectedSlaveChannels )
{
if ( slave != null )
{
connectedSlaveChannels.put( channel, slave );
}
buffer = channelBuffers.get( channel );
if ( buffer == null )
{
buffer = Pair.of( ChannelBuffers.dynamicBuffer(), ByteBuffer.allocateDirect( 1*1024*1024 ) );
channelBuffers.put( channel, buffer );
}
buffer.first().clear();
}
return buffer;
}
protected void unmapSlave( Channel channel, SlaveContext slave )
{
synchronized ( connectedSlaveChannels )
{
connectedSlaveChannels.remove( channel );
}
}
public void shutdown()
{
// Close all open connections
deadConnectionsPoller.shutdown();
msgLog.logMessage( "Master server shutdown, closing all channels", true );
channelGroup.close().awaitUninterruptibly();
executor.shutdown();
        // TODO This should work, but sometimes blocks in a busy wait
// channelFactory.releaseExternalResources();
}
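    /*
     * Runs periodically on deadConnectionsPoller: a slave whose channel is no
     * longer open gets its ongoing transaction finished on the master, and
     * all per-channel state is cleaned up so buffers don't leak.
     */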
private void checkForDeadChannels()
{
synchronized ( connectedSlaveChannels )
{
Collection<Channel> channelsToRemove = new ArrayList<Channel>();
for ( Map.Entry<Channel, SlaveContext> entry : connectedSlaveChannels.entrySet() )
{
if ( !channelIsOpen( entry.getKey() ) )
{
System.out.println( "Found dead channel " + entry.getKey() + ", " + entry.getValue() );
realMaster.finishTransaction( entry.getValue() );
System.out.println( "Removed " + entry.getKey() + ", " + entry.getValue() );
channelsToRemove.add( entry.getKey() );
}
}
for ( Channel channel : channelsToRemove )
{
connectedSlaveChannels.remove( channel );
channelBuffers.remove( channel );
partialRequests.remove( channel );
}
}
}
private boolean channelIsOpen( Channel channel )
{
        /*
         * "Open" is the weakest form of connectedness, while "connected"
         * implies the channel can actually carry data, so check both.
         */
        return channel.isConnected() && channel.isOpen();
}
    // =====================================================================
    // Methods that aren't used when running an HA cluster, but are exposed
    // so that other tools can access this information.
    // =====================================================================
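    /*
     * Returns a map from slave machine id to that slave's ongoing
     * transactions; slaves that are connected but have no ongoing
     * transaction map to an empty collection.
     */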
public Map<Integer, Collection<SlaveContext>> getSlaveInformation()
{
        // Which slaves are connected at the moment?
Set<Integer> machineIds = new HashSet<Integer>();
synchronized ( connectedSlaveChannels )
{
for ( SlaveContext context : this.connectedSlaveChannels.values() )
{
machineIds.add( context.machineId() );
}
}
        // Insert missing slaves into the map so that all connected slaves
        // are in the returned map
Map<Integer, Collection<SlaveContext>> ongoingTransactions =
((MasterImpl) realMaster).getOngoingTransactions();
for ( Integer machineId : machineIds )
{
if ( !ongoingTransactions.containsKey( machineId ) )
{
ongoingTransactions.put( machineId, Collections.<SlaveContext>emptyList() );
}
}
return new TreeMap<Integer, Collection<SlaveContext>>( ongoingTransactions );
}
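    /*
     * Accumulates the payload of a request that arrives split over several
     * chunks, together with the metadata read from its first chunk.
     */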
static class PartialRequest
{
final SlaveContext slaveContext;
final Pair<ChannelBuffer, ByteBuffer> buffers;
final RequestType type;
public PartialRequest( RequestType type, SlaveContext slaveContext, Pair<ChannelBuffer, ByteBuffer> buffers )
{
this.type = type;
this.slaveContext = slaveContext;
this.buffers = buffers;
}
public void add( ChannelBuffer buffer )
{
this.buffers.first().writeBytes( buffer );
}
}
}