/** * Copyright 2009 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package org.waveprotocol.wave.examples.fedone.rpc; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.inject.Inject; import com.google.inject.name.Named; import com.google.protobuf.Descriptors; import com.google.protobuf.Message; import com.google.protobuf.RpcCallback; import com.google.protobuf.Service; import com.google.protobuf.UnknownFieldSet; import com.google.protobuf.Descriptors.MethodDescriptor; import org.waveprotocol.wave.examples.fedone.util.Log; import org.eclipse.jetty.server.Connector; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.nio.SelectChannelConnector; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.websocket.WebSocket; import org.eclipse.jetty.websocket.WebSocketServlet; import java.io.IOException; import java.net.SocketAddress; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import javax.servlet.http.HttpServletRequest; /** * ServerRpcProvider can provide instances of type Service over an incoming * network socket and service incoming RPCs to these services and their methods. * * */ public class ServerRpcProvider { private static final Log LOG = Log.get(ServerRpcProvider.class); private final SocketAddress rpcHostingAddress; private final String websocketHost; private final Integer websocketPort; private final Set<Connection> incomingConnections = Sets.newHashSet(); private final ExecutorService threadPool; private ServerSocketChannel rpcServer = null; private Server websocketServer = null; private Future<?> acceptorThread = null; // Mapping from incoming protocol buffer type -> specific handler. private final Map<Descriptors.Descriptor, RegisteredServiceMethod> registeredServices = Maps.newHashMap(); /** * Internal, static container class for any specific registered service method. */ static class RegisteredServiceMethod { final Service service; final MethodDescriptor method; RegisteredServiceMethod(Service service, MethodDescriptor method) { this.service = service; this.method = method; } } class SequencedProtoChannelConnection extends Connection { private final SequencedProtoChannel protoChannel; private final SocketChannel channel; SequencedProtoChannelConnection(SocketChannel channel) { this.channel = channel; LOG.info("New Connection set up from " + this.channel); // Set up protoChannel, let it know to expect messages of all the // registered service/method types. // TODO: dynamic lookup for these types instead protoChannel = new SequencedProtoChannel(channel, this, threadPool); expectMessages(protoChannel); protoChannel.startAsyncRead(); } protected void sendMessage(long sequenceNo, Message message) { protoChannel.sendMessage(sequenceNo, message); } } class WebSocketConnection extends Connection { private WebSocketServerChannel socketChannel; WebSocketConnection() { socketChannel = new WebSocketServerChannel(this); LOG.info("New websocket connection set up."); expectMessages(socketChannel); } protected void sendMessage(long sequenceNo, Message message) { socketChannel.sendMessage(sequenceNo, message); } public WebSocketServerChannel getWebSocketServerChannel() { return socketChannel; } } abstract class Connection implements ProtoCallback { private final Map<Long, ServerRpcController> activeRpcs = new ConcurrentHashMap<Long, ServerRpcController>(); protected void expectMessages(MessageExpectingChannel channel) { synchronized (registeredServices) { for (RegisteredServiceMethod serviceMethod : registeredServices.values()) { channel.expectMessage(serviceMethod.service.getRequestPrototype(serviceMethod.method)); LOG.fine("Expecting: " + serviceMethod.method.getFullName()); } } channel.expectMessage(Rpc.CancelRpc.getDefaultInstance()); } protected abstract void sendMessage(long sequenceNo, Message message); @Override public void message(final long sequenceNo, Message message) { if (message instanceof Rpc.CancelRpc) { final ServerRpcController controller = activeRpcs.get(sequenceNo); if (controller == null) { throw new IllegalStateException("Trying to cancel an RPC that is not active!"); } else { LOG.info("Cancelling open RPC " + sequenceNo); controller.cancel(); } } else if (registeredServices.containsKey(message.getDescriptorForType())) { if (activeRpcs.containsKey(sequenceNo)) { throw new IllegalStateException( "Can't invoke a new RPC with a sequence number already in use."); } else { final RegisteredServiceMethod serviceMethod = registeredServices.get(message.getDescriptorForType()); // Create the internal ServerRpcController used to invoke the call. final ServerRpcController controller = new ServerRpcController(message, serviceMethod.service, serviceMethod.method, new RpcCallback<Message>() { @Override synchronized public void run(Message message) { if (message instanceof Rpc.RpcFinished || !serviceMethod.method.getOptions().getExtension(Rpc.isStreamingRpc)) { // This RPC is over - remove it from the map. boolean failed = message instanceof Rpc.RpcFinished ? ((Rpc.RpcFinished) message).getFailed() : false; LOG.fine("RPC " + sequenceNo + " is now finished, failed = " + failed); if (failed) { LOG.info("error = " + ((Rpc.RpcFinished) message).getErrorText()); } activeRpcs.remove(sequenceNo); } sendMessage(sequenceNo, message); } }); // Kick off a new thread specific to this RPC. activeRpcs.put(sequenceNo, controller); threadPool.execute(controller); } } else { // Sent a message type we understand, but don't expect - erronous case! throw new IllegalStateException("Got expected but unknown message (" + message + ") for sequence: " + sequenceNo); } } @Override public void unknown(long sequenceNo, String messageType, UnknownFieldSet message) { throw new IllegalStateException("Got unknown message (type: " + messageType + ", " + message + ") for sequence: " + sequenceNo); } @Override public void unknown(long sequenceNo, String messageType, String message) { throw new IllegalStateException("Got unknown message (type: " + messageType + ", " + message + ") for sequence: " + sequenceNo); } } /** * Construct a new ServerRpcProvider, hosting on the passed SocketAddress * and WebSocket host and port. (The websocket isn't passed in as a * SocketAddress beacuse Jetty requires host + port.) * * Also accepts an ExecutorService for spawning managing threads. * * @param rpcHost the hosting socket * @param websocketHost host for websocket server * @param websocketPort port for websocket server * @param threadPool the service used to create threads */ public ServerRpcProvider(SocketAddress rpcHost, String websocketHost, Integer websocketPort, ExecutorService threadPool) { rpcHostingAddress = rpcHost; this.websocketHost = websocketHost; this.websocketPort = websocketPort; this.threadPool = threadPool; } /** * Constructs a new ServerRpcProvider with a default ExecutorService. */ public ServerRpcProvider(SocketAddress rpcHost, String websocketHost, Integer websocketPort) { this(rpcHost, websocketHost, websocketPort, Executors.newCachedThreadPool()); } @Inject public ServerRpcProvider(@Named("client_frontend_hostname") String rpcHost, @Named("client_frontend_port") Integer rpcPort, @Named("websocket_frontend_hostname") String websocketHost, @Named("websocket_frontend_port") Integer websocketPort) { this(new InetSocketAddress(rpcHost, rpcPort), websocketHost, websocketPort); } /** * Starts this server, binding to the previously passed SocketAddress. */ public void startRpcServer() throws IOException { rpcServer = ServerSocketChannel.open(); rpcServer.socket().setReuseAddress(true); rpcServer.socket().bind(rpcHostingAddress); rpcServer.configureBlocking(true); // Spawn a new server acceptor thread, which must accept incoming // connections indefinitely - until a ClosedChannelException is thrown. acceptorThread = threadPool.submit(new Runnable() { @Override public void run() { try { LOG.fine("ServerRpcProvider acceptorThread waiting for connections."); while (true) { SocketChannel serverSocket = rpcServer.accept(); incomingConnections.add(new SequencedProtoChannelConnection(serverSocket)); } } catch (ClosedChannelException e) { return; } catch (IOException e) { throw new IllegalStateException("Server should not throw a misunderstood IOException", e); } } }); } public void startWebSocketServer() { websocketServer = new Server(); Connector c = new SelectChannelConnector(); c.setHost(websocketHost); c.setPort(websocketPort); websocketServer.addConnector(c); ServletContextHandler context = new ServletContextHandler(); context.setContextPath("/"); websocketServer.setHandler(context); ServletHolder holder = new ServletHolder(new WaveWebSocketServlet()); holder.setInitParameter("bufferSize", ""+1024*1024); // 1M buffer. TODO(zamfi): fix to let messages span frames. holder.setInitParameter("maxIdleTime", "-1"); context.addServlet(holder, "/"); try { websocketServer.start(); } catch (Exception e) { // yes, .start() throws "Exception" LOG.severe("Fatal error starting websocket server.", e); return; } LOG.fine("WebSocket server running."); } public class WaveWebSocketServlet extends WebSocketServlet { protected WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) { WebSocketConnection connection = new WebSocketConnection(); return connection.getWebSocketServerChannel(); } } /** * Returns the bound socket. This is null if this server is not running. */ public SocketAddress getBoundAddress() { return rpcServer != null ? rpcServer.socket().getLocalSocketAddress() : null; } /** * Returns the socket the WebSocket server is listening on. */ public SocketAddress getWebSocketAddress() { if (websocketServer == null) { return null; } else { Connector c = websocketServer.getConnectors()[0]; return new InetSocketAddress(c.getHost(), c.getLocalPort()); } } /** * Stops this server. */ public void stopServer() throws IOException { if (rpcServer != null) { rpcServer.close(); } try { websocketServer.stop(); // yes, .stop() throws "Exception" } catch (Exception e) { LOG.warning("Fatal error stopping websocket server.", e); } if (acceptorThread != null) { try { acceptorThread.get(); } catch (InterruptedException e) { throw new IllegalStateException(); } catch (ExecutionException e) { throw new IllegalStateException("Server thread threw an exception", e.getCause()); } if (!acceptorThread.isDone()) { throw new IllegalStateException("Server acceptor thread has not stopped."); } } LOG.fine("server shutdown."); } /** * Register all methods provided by the given service type. */ public void registerService(Service service) { synchronized (registeredServices) { for (MethodDescriptor methodDescriptor : service.getDescriptorForType().getMethods()) { registeredServices.put(methodDescriptor.getInputType(), new RegisteredServiceMethod(service, methodDescriptor)); } } } }