/**
 * Copyright (c) 2002-2012 "Neo Technology,"
 * Network Engine for Objects in Lund AB [http://neotechnology.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.com;

import static org.neo4j.com.DechunkingChannelBuffer.assertSameProtocolVersion;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.neo4j.com.RequestContext.Tx;
import org.neo4j.helpers.Exceptions;
import org.neo4j.helpers.NamedThreadFactory;
import org.neo4j.helpers.Pair;
import org.neo4j.helpers.Triplet;
import org.neo4j.helpers.collection.IteratorUtil;
import org.neo4j.kernel.configuration.Config;
import org.neo4j.kernel.impl.nioneo.store.StoreId;
import org.neo4j.kernel.impl.util.StringLogger;
import org.neo4j.kernel.lifecycle.Lifecycle;

/**
 * Receives requests from {@link Client clients}. Delegates actual work to an instance
 * of a specified communication interface, injected in the constructor.
 *
 * frameLength vs. chunkSize: frameLength is the maximum, hardcoded size of each
 * Netty buffer created by this server and handed off to a {@link Client}. If the
 * client has a smaller frameLength than this server, it will fail when reading a frame
 * bigger than its own frameLength.
 * chunkSize is the maximum size a buffer will reach before it's sent off and a new
 * buffer is allocated to continue writing to.
 * frameLength should be a constant for an implementation and must have the same value
 * on the server as on clients connecting to that server, whereas chunkSize may well
 * be configurable and vary between server and client.
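 *
 * As a minimal sketch (hypothetical, illustrative values only), a configuration could
 * look like this; note that {@link #start()} asserts that the configured chunkSize
 * fits within the frameLength given to the constructor:
 * <pre>
 * Server.Configuration config = new Server.Configuration()
 * {
 *     public long getOldChannelThreshold() { return 60 * 1000; } // ms of silence tolerated
 *     public int getMaxConcurrentTransactions() { return Server.DEFAULT_MAX_NUMBER_OF_CONCURRENT_TRANSACTIONS; }
 *     public int getPort() { return Server.DEFAULT_BACKUP_PORT; }
 *     public int getChunkSize() { return 2 * 1024 * 1024; } // may differ from the client's
 *     public String getServerAddress() { return null; } // null binds to all interfaces
 * };
 * </pre>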
 *
 * @see Client
 */
public abstract class Server<T, R> extends Protocol implements ChannelPipelineFactory, Lifecycle
{
    public interface Configuration
    {
        long getOldChannelThreshold();

        int getMaxConcurrentTransactions();

        int getPort();

        int getChunkSize();

        String getServerAddress();
    }

    static final byte INTERNAL_PROTOCOL_VERSION = 2;

    public static final int DEFAULT_BACKUP_PORT = 6362;

    // It's ok if there are more concurrent transactions than this, since these worker
    // threads don't do any actual work themselves, but spawn off other worker threads
    // doing the actual work. So this is more like a core Netty I/O pool worker size.
    public final static int DEFAULT_MAX_NUMBER_OF_CONCURRENT_TRANSACTIONS = 200;

    private ChannelFactory channelFactory;
    private ServerBootstrap bootstrap;
    private T requestTarget;
    private ChannelGroup channelGroup;
    private final Map<Channel, Pair<RequestContext, AtomicLong /*time last heard of*/>> connectedSlaveChannels =
            new ConcurrentHashMap<Channel, Pair<RequestContext, AtomicLong>>();
    private ExecutorService executor;
    private ExecutorService workerExecutor;
    private ExecutorService targetCallExecutor;
    private StringLogger msgLog;
    private final Map<Channel, PartialRequest> partialRequests = new ConcurrentHashMap<Channel, PartialRequest>();
    private Configuration config;
    private final int frameLength;
    private volatile boolean shuttingDown;

    // Executor for channels that we know should be finished, but can't be due to being
    // active at the moment.
    private ExecutorService unfinishedTransactionExecutor;

    // This is because there's a bug in Netty causing some channelClosed/channelDisconnected
    // events to not be sent. This is merely a safety net to catch the remainder of the closed
    // channels that Netty doesn't tell us about.
    private ScheduledExecutorService silentChannelExecutor;

    private final byte applicationProtocolVersion;
    private long oldChannelThresholdMillis;
    private TxChecksumVerifier txVerifier;
    private int chunkSize;

    public Server( T requestTarget, Configuration config, StringLogger logger, int frameLength,
            byte applicationProtocolVersion, TxChecksumVerifier txVerifier )
    {
        this.requestTarget = requestTarget;
        this.config = config;
        this.frameLength = frameLength;
        this.applicationProtocolVersion = applicationProtocolVersion;
        this.msgLog = logger;
        this.txVerifier = txVerifier;
    }

    @Override
    public void init() throws Throwable
    {
    }

    @Override
    public void start() throws Throwable
    {
        this.oldChannelThresholdMillis = config.getOldChannelThreshold();
        chunkSize = config.getChunkSize();
        assertChunkSizeIsWithinFrameSize( chunkSize, frameLength );

        executor = Executors.newCachedThreadPool( new NamedThreadFactory( "Server receiving" ) );
        workerExecutor = Executors.newCachedThreadPool( new NamedThreadFactory( "Server receiving" ) );
        targetCallExecutor = Executors.newCachedThreadPool(
                new NamedThreadFactory( getClass().getSimpleName() + ":" + config.getPort() ) );
        unfinishedTransactionExecutor =
                Executors.newScheduledThreadPool( 2, new NamedThreadFactory( "Unfinished transactions" ) );
        channelFactory =
                new NioServerSocketChannelFactory( executor, workerExecutor, config.getMaxConcurrentTransactions() );
        silentChannelExecutor =
                Executors.newSingleThreadScheduledExecutor( new NamedThreadFactory( "Silent channel reaper" ) );
        silentChannelExecutor.scheduleWithFixedDelay( silentChannelFinisher(), 5, 5, TimeUnit.SECONDS );
        bootstrap = new ServerBootstrap( channelFactory );
        bootstrap.setPipelineFactory( this );

        Channel channel;
        InetSocketAddress socketAddress;
        if ( config.getServerAddress() == null )
        {
            socketAddress = new InetSocketAddress( config.getPort() );
        }
        else
        {
            socketAddress = new InetSocketAddress( config.getServerAddress(), config.getPort() );
        }
        try
        {
            channel = bootstrap.bind( socketAddress );
        }
        catch ( ChannelException e )
        {
            msgLog.logMessage( "Failed to bind server to " + socketAddress, e );
            executor.shutdown();
            workerExecutor.shutdown();
            throw new IOException( e );
        }
        channelGroup = new DefaultChannelGroup();
        channelGroup.add( channel );
        msgLog.logMessage( getClass().getSimpleName() + " communication server started and bound to " + socketAddress,
                true );
    }

    @Override
    public void stop() throws Throwable
    {
        // Close all open connections
        shuttingDown = true;
        silentChannelExecutor.shutdown();
        unfinishedTransactionExecutor.shutdown();
        targetCallExecutor.shutdown();
        channelGroup.close().awaitUninterruptibly();
        channelFactory.releaseExternalResources();
    }

    @Override
    public void shutdown() throws Throwable
    {
    }

    private Runnable silentChannelFinisher()
    {
        // This poller is here because sometimes Netty doesn't tell us when channels are
        // closed or disconnected. Most of the time it does, but this acts as a safety
        // net for those we don't get notifications for. When the bug is fixed, remove this.
        return new Runnable()
        {
            @Override
            public void run()
            {
                Map<Channel, Boolean/*starting to get old?*/> channels = new HashMap<Channel, Boolean>();
                synchronized ( connectedSlaveChannels )
                {
                    for ( Map.Entry<Channel, Pair<RequestContext, AtomicLong>> channel : connectedSlaveChannels
                            .entrySet() )
                    {
                        // Has this channel been silent for a while?
                        long age = System.currentTimeMillis() - channel.getValue().other().get();
                        if ( age > oldChannelThresholdMillis )
                        {
                            msgLog.logMessage( "Found a silent channel " + channel + ", " + age );
                            channels.put( channel.getKey(), Boolean.TRUE );
                        }
                        else if ( age > oldChannelThresholdMillis / 2 )
                        {
                            // Then add it to a list to check
                            channels.put( channel.getKey(), Boolean.FALSE );
                        }
                    }
                }
                for ( Map.Entry<Channel, Boolean> channel : channels.entrySet() )
                {
                    if ( channel.getValue() || !channel.getKey().isOpen() || !channel.getKey().isConnected()
                            || !channel.getKey().isBound() )
                    {
                        tryToFinishOffChannel( channel.getKey() );
                    }
                }
            }
        };
    }

    /**
     * Only exposed so that tests can control it. It's not configurable really.
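     * Both this version and the application protocol version are checked for every request
     * in {@code readContinuationHeader}; a mismatch is sent back to the client as a
     * serialized {@link IllegalProtocolVersionException}.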
     */
    protected byte getInternalProtocolVersion()
    {
        return INTERNAL_PROTOCOL_VERSION;
    }

    public ChannelPipeline getPipeline() throws Exception
    {
        ChannelPipeline pipeline = Channels.pipeline();
        addLengthFieldPipes( pipeline, frameLength );
        pipeline.addLast( "serverHandler", new ServerHandler() );
        return pipeline;
    }

    private class ServerHandler extends SimpleChannelHandler
    {
        @Override
        public void channelOpen( ChannelHandlerContext ctx, ChannelStateEvent e ) throws Exception
        {
            channelGroup.add( e.getChannel() );
        }

        @Override
        public void messageReceived( ChannelHandlerContext ctx, MessageEvent event ) throws Exception
        {
            try
            {
                ChannelBuffer message = (ChannelBuffer) event.getMessage();
                handleRequest( message, event.getChannel() );
            }
            catch ( Throwable e )
            {
                msgLog.logMessage( "Error handling request", e );
                ctx.getChannel().close();
                tryToFinishOffChannel( ctx.getChannel() );
                throw Exceptions.launderedException( e );
            }
        }

        @Override
        public void channelClosed( ChannelHandlerContext ctx, ChannelStateEvent e ) throws Exception
        {
            super.channelClosed( ctx, e );
            if ( !ctx.getChannel().isOpen() )
            {
                tryToFinishOffChannel( ctx.getChannel() );
            }
            channelGroup.remove( e.getChannel() );
        }

        @Override
        public void channelDisconnected( ChannelHandlerContext ctx, ChannelStateEvent e ) throws Exception
        {
            super.channelDisconnected( ctx, e );
            if ( !ctx.getChannel().isConnected() )
            {
                tryToFinishOffChannel( ctx.getChannel() );
            }
        }

        @Override
        public void exceptionCaught( ChannelHandlerContext ctx, ExceptionEvent e ) throws Exception
        {
            msgLog.warn( "Exception from Netty", e.getCause() );
        }
    }

    protected void tryToFinishOffChannel( Channel channel )
    {
        Pair<RequestContext, AtomicLong> slave = null;
        synchronized ( connectedSlaveChannels )
        {
            slave = connectedSlaveChannels.remove( channel );
        }
        if ( slave == null )
        {
            return;
        }
        tryToFinishOffChannel( channel, slave.first() );
    }

    protected void tryToFinishOffChannel( Channel channel, RequestContext slave )
    {
        try
        {
            finishOffChannel( channel, slave );
            unmapSlave( channel, slave );
        }
        catch ( Throwable failure ) // Unknown error trying to finish off the tx
        {
            submitSilent( unfinishedTransactionExecutor, newTransactionFinisher( slave ) );
            if ( shouldLogFailureToFinishOffChannel( failure ) )
            {
                msgLog.logMessage( "Could not finish off dead channel", failure );
            }
        }
    }

    protected boolean shouldLogFailureToFinishOffChannel( Throwable failure )
    {
        return true;
    }

    private void submitSilent( ExecutorService service, Runnable job )
    {
        try
        {
            service.submit( job );
        }
        catch ( RejectedExecutionException e )
        {
            // Don't scream and shout if we're shutting down, because a rejected execution
            // is expected at that time.
            if ( !shuttingDown )
            {
                throw e;
            }
        }
    }

    private Runnable newTransactionFinisher( final RequestContext slave )
    {
        return new Runnable()
        {
            @Override
            public void run()
            {
                try
                {
                    finishOffChannel( null, slave );
                }
                catch ( Throwable e )
                {
                    // Introduce some delay here; otherwise it becomes like a busy wait
                    // if it never succeeds.
                    sleepNicely( 200 );
                    unfinishedTransactionExecutor.submit( newTransactionFinisher( slave ) );
                }
            }

            private void sleepNicely( int millis )
            {
                try
                {
                    Thread.sleep( millis );
                }
                catch ( InterruptedException e )
                {
                    Thread.interrupted();
                }
            }
        };
    }

    protected void handleRequest( ChannelBuffer buffer, final Channel channel ) throws IOException
    {
        Byte continuation = readContinuationHeader( buffer, channel );
        if ( continuation == null )
        {
            return;
        }
        if ( continuation == ChunkingChannelBuffer.CONTINUATION_MORE )
        {
            PartialRequest partialRequest = partialRequests.get( channel );
            if ( partialRequest == null )
            {
                // This is the first chunk in a multi-chunk request
                RequestType<T> type = getRequestContext( buffer.readByte() );
                RequestContext context = readContext( buffer );
                ChannelBuffer targetBuffer = mapSlave( channel, context, type );
                partialRequest = new PartialRequest( type, context, targetBuffer );
                partialRequests.put( channel, partialRequest );
            }
            partialRequest.add( buffer );
        }
        else
        {
            PartialRequest partialRequest = partialRequests.remove( channel );
            RequestType<T> type = null;
            RequestContext context = null;
            ChannelBuffer targetBuffer;
            ChannelBuffer bufferToReadFrom = null;
            ChannelBuffer bufferToWriteTo = null;
            if ( partialRequest == null )
            {
                // This is the one and only chunk in the request
                type = getRequestContext( buffer.readByte() );
                context = readContext( buffer );
                targetBuffer = mapSlave( channel, context, type );
                bufferToReadFrom = buffer;
                bufferToWriteTo = targetBuffer;
            }
            else
            {
                // This is the last chunk in a multi-chunk request
                type = partialRequest.type;
                context = partialRequest.context;
                targetBuffer = partialRequest.buffer;
                partialRequest.add( buffer );
                bufferToReadFrom = targetBuffer;
                bufferToWriteTo = ChannelBuffers.dynamicBuffer();
            }

            bufferToWriteTo.clear();
            final ChunkingChannelBuffer chunkingBuffer = new ChunkingChannelBuffer( bufferToWriteTo, channel,
                    chunkSize, getInternalProtocolVersion(), applicationProtocolVersion );
            submitSilent( targetCallExecutor,
                    targetCaller( type, channel, context, chunkingBuffer, bufferToReadFrom ) );
        }
    }

    private Byte readContinuationHeader( ChannelBuffer buffer, final Channel channel )
    {
        byte[] header = new byte[2];
        buffer.readBytes( header );
        try
        {
            // Read request header and assert correct internal/application protocol version
            assertSameProtocolVersion( header, getInternalProtocolVersion(), applicationProtocolVersion );
        }
        catch ( final IllegalProtocolVersionException e )
        {
            // Version mismatch, fail with a good exception back to the client
            final ChunkingChannelBuffer failureResponse = new ChunkingChannelBuffer( ChannelBuffers.dynamicBuffer(),
                    channel, chunkSize, getInternalProtocolVersion(), applicationProtocolVersion );
            submitSilent( targetCallExecutor, new Runnable()
            {
                @Override
                public void run()
                {
                    writeFailureResponse( e, failureResponse );
                }
            } );
            return null;
        }
        return (byte) (header[0] & 0x1);
    }

    protected Runnable targetCaller( final RequestType<T> type, final Channel channel, final RequestContext context,
            final ChunkingChannelBuffer targetBuffer, final ChannelBuffer bufferToReadFrom )
    {
        return new Runnable()
        {
            @SuppressWarnings( "unchecked" )
            public void run()
            {
                Response<R> response = null;
                try
                {
                    response = type.getTargetCaller().call( requestTarget, context, bufferToReadFrom, targetBuffer );
                    type.getObjectSerializer().write( response.response(), targetBuffer );
                    writeStoreId( response.getStoreId(), targetBuffer );
                    writeTransactionStreams( response.transactions(), targetBuffer );
                    targetBuffer.done();
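                    // At this point the entire response (serialized payload, store id and
                    // any transaction streams) has been chunked and flushed to the channel.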
                    responseWritten( type, channel, context );
                }
                catch ( Throwable e )
                {
                    targetBuffer.clear( true );
                    writeFailureResponse( e, targetBuffer );
                    tryToFinishOffChannel( channel, context );
                    throw Exceptions.launderedException( e );
                }
                finally
                {
                    if ( response != null )
                    {
                        response.close();
                    }
                    unmapSlave( channel, context );
                }
            }
        };
    }

    protected void writeFailureResponse( Throwable exception, ChunkingChannelBuffer buffer )
    {
        try
        {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            ObjectOutputStream out = new ObjectOutputStream( bytes );
            out.writeObject( exception );
            out.close();
            buffer.writeBytes( bytes.toByteArray() );
            buffer.done();
        }
        catch ( IOException e )
        {
            msgLog.logMessage( "Couldn't send cause of error to client", exception );
        }
    }

    protected void responseWritten( RequestType<T> type, Channel channel, RequestContext context )
    {
    }

    private static void writeStoreId( StoreId storeId, ChannelBuffer targetBuffer )
    {
        targetBuffer.writeBytes( storeId.serialize() );
    }

    private static void writeTransactionStreams( TransactionStream txStream, ChannelBuffer buffer ) throws IOException
    {
        if ( !txStream.hasNext() )
        {
            buffer.writeByte( 0 );
            return;
        }

        String[] datasources = txStream.dataSourceNames();
        assert datasources.length <= 255 : "too many data sources";
        buffer.writeByte( datasources.length );
        Map<String, Integer> datasourceId = new HashMap<String, Integer>();
        for ( int i = 0; i < datasources.length; i++ )
        {
            String datasource = datasources[i];
            writeString( buffer, datasource );
            datasourceId.put( datasource, i + 1/*0 means "no more transactions"*/ );
        }
        for ( Triplet<String, Long, TxExtractor> tx : IteratorUtil.asIterable( txStream ) )
        {
            buffer.writeByte( datasourceId.get( tx.first() ) );
            buffer.writeLong( tx.second() );
            BlockLogBuffer blockBuffer = new BlockLogBuffer( buffer );
            tx.third().extract( blockBuffer );
            blockBuffer.done();
        }
        buffer.writeByte( 0/*no more transactions*/ );
    }

    protected RequestContext readContext( ChannelBuffer buffer )
    {
        long sessionId = buffer.readLong();
        int machineId = buffer.readInt();
        int eventIdentifier = buffer.readInt();
        int txsSize = buffer.readByte();
        Tx[] lastAppliedTransactions = new Tx[txsSize];
        Tx neoTx = null;
        for ( int i = 0; i < txsSize; i++ )
        {
            String ds = readString( buffer );
            Tx tx = RequestContext.lastAppliedTx( ds, buffer.readLong() );
            lastAppliedTransactions[i] = tx;

            // Only perform checksum checks on the neo data source.
            if ( ds.equals( Config.DEFAULT_DATA_SOURCE_NAME ) )
            {
                neoTx = tx;
            }
        }
        int masterId = buffer.readInt();
        long checksum = buffer.readLong();

        // Only perform checksum checks on the neo data source. If there's none in the request
        // then don't perform any such check.
        if ( neoTx != null )
        {
            txVerifier.assertMatch( neoTx.getTxId(), masterId, checksum );
        }
        return new RequestContext( sessionId, machineId, eventIdentifier, lastAppliedTransactions, masterId,
                checksum );
    }

    protected abstract RequestType<T> getRequestContext( byte id );

    protected ChannelBuffer mapSlave( Channel channel, RequestContext slave, RequestType<T> type )
    {
        synchronized ( connectedSlaveChannels )
        {
            // Checking for machineId -1 excludes the "empty" slave contexts
            // which some communication points pass in as context.
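            // Mapping a slave registers the channel and refreshes its "time last heard of"
            // timestamp, which the silent channel reaper uses to find silent channels.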
            if ( slave != null && slave.machineId() != RequestContext.EMPTY.machineId() )
            {
                Pair<RequestContext, AtomicLong> previous = connectedSlaveChannels.get( channel );
                if ( previous != null )
                {
                    previous.other().set( System.currentTimeMillis() );
                }
                else
                {
                    connectedSlaveChannels.put( channel,
                            Pair.of( slave, new AtomicLong( System.currentTimeMillis() ) ) );
                }
            }
        }
        return ChannelBuffers.dynamicBuffer();
    }

    protected void unmapSlave( Channel channel, RequestContext slave )
    {
        synchronized ( connectedSlaveChannels )
        {
            connectedSlaveChannels.remove( channel );
        }
    }

    protected T getRequestTarget()
    {
        return requestTarget;
    }

    protected abstract void finishOffChannel( Channel channel, RequestContext context );

    // =====================================================================
    // Just some methods which aren't really used when running an HA cluster,
    // but exposed so that other tools can reach that information.
    // =====================================================================

    public Map<Channel, RequestContext> getConnectedSlaveChannels()
    {
        Map<Channel, RequestContext> result = new HashMap<Channel, RequestContext>();
        synchronized ( connectedSlaveChannels )
        {
            for ( Map.Entry<Channel, Pair<RequestContext, AtomicLong>> entry : connectedSlaveChannels.entrySet() )
            {
                result.put( entry.getKey(), entry.getValue().first() );
            }
        }
        return result;
    }

    private class PartialRequest
    {
        final RequestContext context;
        final ChannelBuffer buffer;
        final RequestType<T> type;

        public PartialRequest( RequestType<T> type, RequestContext context, ChannelBuffer buffer )
        {
            this.type = type;
            this.context = context;
            this.buffer = buffer;
        }

        public void add( ChannelBuffer buffer )
        {
            this.buffer.writeBytes( buffer );
        }
    }
}