package edu.brown.protorpc; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.SelectableChannel; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import org.apache.log4j.Logger; import ca.evanjones.protorpc.Protocol; import ca.evanjones.protorpc.Protocol.RpcRequest; import ca.evanjones.protorpc.Protocol.RpcResponse; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.RpcCallback; import com.google.protobuf.RpcController; import com.google.protobuf.Service; import edu.brown.net.NonBlockingConnection; public class ProtoServer extends AbstractEventHandler { private static final Logger LOG = Logger.getLogger(ProtoServer.class); public ProtoServer(EventLoop eventLoop) { this.eventLoop = eventLoop; } @Override public void acceptCallback(SelectableChannel channel) { // accept the connection assert channel == serverSocket; SocketChannel client; try { client = serverSocket.accept(); } catch (IOException e) { throw new RuntimeException(e); } assert client != null; // wrap it in a message connection and register with event loop ProtoConnection connection = new ProtoConnection(new NonBlockingConnection(client)); eventLoop.registerRead(client, new EventCallbackWrapper(connection)); // SelectionKey clientKey = connection.register(selector); // clientKey.attach(connection); // eventQueue.add(new Event(connection, null)); } private class EventCallbackWrapper extends AbstractEventHandler { public EventCallbackWrapper(ProtoConnection connection) { this.connection = connection; } @Override public void readCallback(SelectableChannel channel) { read(this); } @Override public synchronized boolean writeCallback(SelectableChannel channel) { return connection.writeAvailable(); } private final ProtoConnection connection; public synchronized void writeResponse(RpcResponse output) { boolean blocked = connection.tryWrite(output); if (blocked) { // write blocked: wait for the write callback eventLoop.registerWrite(connection.getChannel(), this); } } } private void read(EventCallbackWrapper eventLoopCallback) { boolean isOpen = eventLoopCallback.connection.readAllAvailable(); if (!isOpen) { // connection closed LOG.debug("Connection closed"); eventLoopCallback.connection.close(); return; } while (true) { RpcRequest.Builder requestBuilder = RpcRequest.newBuilder(); boolean hasMessage = eventLoopCallback.connection.readBufferedMessage(requestBuilder); if (!hasMessage) { break; } RpcRequest request = requestBuilder.build(); // System.out.println(request.getMethodName() + " " + request.getRequest().size()); // Handle the request ProtoMethodInvoker invoker = serviceRegistry.getInvoker(request.getMethodName()); // TODO: Reuse callback objects? ProtoServerCallback callback = new ProtoServerCallback(eventLoopCallback, request.getSequenceNumber()); try { invoker.invoke(callback.controller, request.getRequest(), callback); } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); } } } private static final class ProtoServerController implements RpcController { @Override public String errorText() { throw new UnsupportedOperationException("TODO: implement"); } @Override public boolean failed() { throw new UnsupportedOperationException("TODO: implement"); } @Override public boolean isCanceled() { throw new UnsupportedOperationException("TODO: implement"); } @Override public void notifyOnCancel(RpcCallback<Object> callback) { throw new UnsupportedOperationException("TODO: implement"); } @Override public void reset() { throw new UnsupportedOperationException("TODO: implement"); } @Override public void setFailed(String reason) { if (reason == null) { throw new NullPointerException("reason parameter must not be null"); } if (status != Protocol.Status.OK) { throw new IllegalStateException("RPC already failed or returned"); } assert errorReason == null; status = Protocol.Status.ERROR_USER; errorReason = reason; } @Override public void startCancel() { throw new UnsupportedOperationException("TODO: implement"); } private Protocol.Status status = Protocol.Status.OK; private String errorReason; } private static final class ProtoServerCallback implements RpcCallback<Message> { private final ProtoServerController controller = new ProtoServerController(); private EventCallbackWrapper eventLoopCallback; private final int sequence; public ProtoServerCallback(EventCallbackWrapper eventLoopCallback, int sequence) { this.eventLoopCallback = eventLoopCallback; this.sequence = sequence; assert this.eventLoopCallback != null; } @Override public void run(Message response) { if (eventLoopCallback == null) { throw new IllegalStateException("response callback must only be called once"); } RpcResponse.Builder responseMessage = RpcResponse.newBuilder(); responseMessage.setSequenceNumber(sequence); assert controller.status != Protocol.Status.INVALID; responseMessage.setStatus(controller.status); if (response != null) { responseMessage.setResponse(response.toByteString()); } else { // No message: we must have failed assert controller.status != Protocol.Status.OK; } if (controller.errorReason != null) { assert controller.status != Protocol.Status.OK; responseMessage.setErrorReason(controller.errorReason); } eventLoopCallback.writeResponse(responseMessage.build()); eventLoopCallback = null; } } public void bind(int port) { try { serverSocket = ServerSocketChannel.open(); // Avoid TIME_WAIT when killing the server serverSocket.socket().setReuseAddress(true); // Mac OS X: bind() before calling Selector.register or you don't get accept() events serverSocket.socket().bind(new InetSocketAddress(port)); eventLoop.registerAccept(serverSocket, this); } catch (IOException e) { throw new RuntimeException("Failed to bind socket on port #" + port, e); } } public void close() { try { serverSocket.close(); } catch (IOException e) { throw new RuntimeException(e); } } public void setServerSocketForTest(ServerSocketChannel serverSocket) { this.serverSocket = serverSocket; } public void register(Service service) { serviceRegistry.register(service); } private EventLoop eventLoop; private ServerSocketChannel serverSocket; private final ServiceRegistry serviceRegistry = new ServiceRegistry(); }