package edu.brown.protorpc; import java.io.IOException; import java.net.ConnectException; import java.net.InetSocketAddress; import java.nio.channels.SelectableChannel; import java.nio.channels.SocketChannel; import java.util.HashMap; import org.apache.log4j.Logger; import ca.evanjones.protorpc.Protocol; import ca.evanjones.protorpc.Protocol.RpcRequest; import ca.evanjones.protorpc.Protocol.RpcResponse; import com.google.protobuf.Descriptors; import com.google.protobuf.Message; import com.google.protobuf.RpcCallback; import com.google.protobuf.RpcChannel; import com.google.protobuf.RpcController; import edu.brown.net.NonBlockingConnection; public class ProtoRpcChannel extends AbstractEventHandler implements RpcChannel { private static final Logger LOG = Logger.getLogger(ProtoRpcChannel.class); private final EventLoop eventLoop; private final ConnectFactory connector; private int sequence; private ProtoConnection connection; private final HashMap<Integer, ProtoRpcController> pendingRpcs = new HashMap<Integer, ProtoRpcController>(); private int reconnectIntervalSeconds; /** A factory interface for connecting to an RPC server. */ public interface ConnectFactory { /** Creates a new connection that is connecting. */ public NonBlockingConnection startNewConnection(); } public ProtoRpcChannel(EventLoop eventLoop, ConnectFactory connector) { this.eventLoop = eventLoop; this.connector = connector; startAsyncConnect(); } private void startAsyncConnect() { assert connection == null; connection = new ProtoConnection(connector.startNewConnection()); if (connection.getChannel() != null && !((SocketChannel) connection.getChannel()).isConnected()) { eventLoop.registerConnect((SocketChannel) connection.getChannel(), this); } else { eventLoop.registerRead(connection.getChannel(), this); } } private static final class NIOConnectFactory implements ConnectFactory { private final InetSocketAddress address; public NIOConnectFactory(InetSocketAddress address) { this.address = address; } @Override public NonBlockingConnection startNewConnection() { try { SocketChannel socket = SocketChannel.open(); NonBlockingConnection connection = new NonBlockingConnection(socket); // this connect is non-blocking and should always return false. boolean finished = ((SocketChannel) connection.getChannel()).connect(address); if (finished) { throw new IllegalStateException("async connect finished instantly?"); } return connection; } catch (IOException e) { throw new RuntimeException(e); } } } public ProtoRpcChannel(EventLoop eventLoop, InetSocketAddress address) { this(eventLoop, new NIOConnectFactory(address)); } private static final class StaticConnectFactory implements ConnectFactory { private NonBlockingConnection connection; public StaticConnectFactory(NonBlockingConnection connection) { this.connection = connection; } @Override public NonBlockingConnection startNewConnection() { NonBlockingConnection out = connection; connection = null; return out; } } /** * Sets the number of seconds to wait before reconnecting, if the connect fails. * This permits a channel to be created without the server running. * * @param reconnectSeconds number of seconds to wait between reconnect attempts. 0 disables * reconnects (default). */ public void setReconnectInterval(int reconnectSeconds) { assert reconnectSeconds >= 0; reconnectIntervalSeconds = reconnectSeconds; } public void callMethod(Descriptors.MethodDescriptor method, RpcController controller, Message request, Message responsePrototype, RpcCallback<Message> done) { ProtoRpcController rpc = (ProtoRpcController) controller; rpc.startRpc(eventLoop, responsePrototype.newBuilderForType(), done); if (connection == null) { // closed connection: fail the RPC rpc.finishRpcFailure(Protocol.Status.ERROR_COMMUNICATION, "Connection closed"); return; } // Package up the request and send it final boolean debug = LOG.isDebugEnabled(); synchronized (this) { pendingRpcs.put(sequence, rpc); // System.err.println("Sending RPC sequence " + sequence); RpcRequest rpcRequest = makeRpcRequest(sequence, method, request); sequence += 1; boolean blocked = connection.tryWrite(rpcRequest); if (blocked) { // the write blocked: wait for write callbacks if (debug) LOG.debug("registering write with eventLoop: " + eventLoop); eventLoop.registerWrite(connection.getChannel(), this); } if (debug) LOG.debug(String.format("%d: Sending RPC %s sequence %d blocked = %b", hashCode(), method.getFullName(), sequence, blocked)); } } public static RpcRequest makeRpcRequest( int sequence, Descriptors.MethodDescriptor method, Message request) { RpcRequest.Builder requestBuilder = RpcRequest.newBuilder(); requestBuilder.setSequenceNumber(sequence); requestBuilder.setMethodName(method.getFullName()); requestBuilder.setRequest(request.toByteString()); return requestBuilder.build(); } @Override public void readCallback(SelectableChannel channel) { boolean isOpen = connection.readAllAvailable(); if (!isOpen) { // TODO: Fail any subsequent RPCs throw new UnsupportedOperationException("Connection closed: not handled (for now)."); } while (true) { RpcResponse.Builder builder = RpcResponse.newBuilder(); boolean success = connection.readBufferedMessage(builder); if (!success) { // TODO: Cache the builder object to reduce garbage? break; } // Set the appropriate flags on the RPC object // TODO: Handle bad sequence number by ignoring/logging? RpcResponse response = builder.build(); ProtoRpcController rpc = null; synchronized (this) { rpc = pendingRpcs.remove(response.getSequenceNumber()); assert response.getStatus() == Protocol.Status.OK; assert rpc != null : "No ProtoRpcController for Sequence# " + response.getSequenceNumber(); } rpc.finishRpcSuccess(response.getResponse()); } } @Override public void connectCallback(SocketChannel channel) { assert channel == connection.getChannel(); try { boolean connected = channel.finishConnect(); assert connected; } catch (ConnectException e) { // If the connection failed for some remote reason (timeout, connection refused) close(); // Reconnection disabled: throw the connect exception if (reconnectIntervalSeconds == 0) { throw new RuntimeException(e); } assert reconnectIntervalSeconds > 0; // We are supposed to reconnect: re-create the connection and schedule a reconnect. eventLoop.registerTimer(reconnectIntervalSeconds * 1000, this); return; } catch (IOException e) { throw new RuntimeException(e); } // register for read events on this connection eventLoop.registerRead(channel, this); boolean blocked = connection.writeAvailable(); if (blocked) { eventLoop.registerWrite(connection.getChannel(), this); } } // This is synchronized so it doesn't screw up a simultaneous callMethod() @Override public synchronized boolean writeCallback(SelectableChannel channel) { boolean blocked = connection.writeAvailable(); if (LOG.isDebugEnabled()) LOG.debug(String.format("%d: writeCallback blocked = %b", hashCode(), blocked)); return blocked; } @Override public void timerCallback() { assert reconnectIntervalSeconds > 0; startAsyncConnect(); } public void close() { if (connection == null) throw new IllegalStateException("connection closed"); connection.close(); connection = null; // Fail all pending RPCs for (ProtoRpcController rpc : pendingRpcs.values()) { // TODO: Define constants shared between C++ and Java? rpc.finishRpcFailure(Protocol.Status.ERROR_COMMUNICATION, "Connection closed"); } pendingRpcs.clear(); } private static final int RECONNECT_TIMEOUT_MS = 2000; static final int TOTAL_CONNECT_TIMEOUT_MS = 30000; public static ProtoRpcChannel[] connectParallel(final EventLoop eventLoop, final InetSocketAddress[] addresses) { return connectParallel(eventLoop, addresses, TOTAL_CONNECT_TIMEOUT_MS); } public static ProtoRpcChannel[] connectParallel(final EventLoop eventLoop, final InetSocketAddress[] addresses, final int total_time) { class ExitLoopHandler extends AbstractEventHandler { @Override public void timerCallback() { assert barrierCount > 0; ((NIOEventLoop) eventLoop).exitLoop(); } public void connectFinished() { barrierCount -= 1; assert barrierCount >= 0; if (barrierCount == 0) { eventLoop.cancelTimer(this); ((NIOEventLoop) eventLoop).exitLoop(); } } private int barrierCount = addresses.length; } final ExitLoopHandler exitLoopHandler = new ExitLoopHandler(); class ConnectHandler extends AbstractEventHandler { public ConnectHandler(int index) { this.index = index; startConnect(); } private void startConnect() { try { channel = SocketChannel.open(); channel.configureBlocking(false); // this connect is non-blocking and should always return false. boolean finished = channel.connect(addresses[index]); if (finished) { throw new IllegalStateException("async connect finished instantly?"); } eventLoop.registerConnect(channel, this); } catch (IOException e) { // TODO Auto-generated catch block throw new RuntimeException(e); } } @Override public void connectCallback(SocketChannel channel) { try { boolean finished = channel.finishConnect(); assert finished; exitLoopHandler.connectFinished(); } catch (ConnectException e) { // Some connection error occurred: retry after a timeout channel = null; eventLoop.registerTimer(RECONNECT_TIMEOUT_MS, this); } catch (IOException e) { throw new RuntimeException(e); } } @Override public void timerCallback() { // reattempt the connection startConnect(); } final int index; SocketChannel channel; } ConnectHandler[] channels = new ConnectHandler[addresses.length]; for (int i = 0; i < channels.length; ++i) { channels[i] = new ConnectHandler(i); } eventLoop.registerTimer(total_time, exitLoopHandler); eventLoop.run(); if (exitLoopHandler.barrierCount == 0) { ProtoRpcChannel[] rpcChannels = new ProtoRpcChannel[addresses.length]; for (int i = 0; i < channels.length; ++i) { rpcChannels[i] = new ProtoRpcChannel(eventLoop, new StaticConnectFactory(new NonBlockingConnection(channels[i].channel))); } return rpcChannels; } else { // Close any open channels in case connects are pending for (ConnectHandler connectHandler : channels) { if (connectHandler.channel == null) { try { connectHandler.channel.close(); } catch (IOException e) { throw new RuntimeException(e); } } } throw new RuntimeException("some connection failed after " + total_time / 1000 + " seconds"); } } }