package org.marketcetera.util.rpc; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import java.io.StringReader; import java.io.StringWriter; import java.util.List; import java.util.Map; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBException; import javax.xml.bind.Marshaller; import javax.xml.bind.Unmarshaller; import org.apache.commons.lang.Validate; import org.marketcetera.util.log.SLF4JLoggerProxy; import org.marketcetera.util.misc.ClassVersion; import org.marketcetera.util.ws.ContextClassProvider; import org.marketcetera.util.ws.stateful.Authenticator; import org.marketcetera.util.ws.stateful.SessionHolder; import org.marketcetera.util.ws.stateful.SessionManager; import org.marketcetera.util.ws.stateless.StatelessClientContext; import org.marketcetera.util.ws.tags.AppId; import org.marketcetera.util.ws.tags.NodeId; import org.marketcetera.util.ws.tags.SessionId; import org.marketcetera.util.ws.tags.VersionId; import org.marketcetera.util.ws.wrappers.LocaleWrapper; import org.springframework.context.Lifecycle; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.protobuf.BlockingService; import com.googlecode.protobuf.pro.duplex.PeerInfo; import com.googlecode.protobuf.pro.duplex.execute.RpcServerCallExecutor; import com.googlecode.protobuf.pro.duplex.execute.ThreadPoolCallExecutor; import com.googlecode.protobuf.pro.duplex.server.DuplexTcpServerPipelineFactory; import com.googlecode.protobuf.pro.duplex.util.RenamingThreadFactoryProxy; /* $License$ */ /** * Provides RPC services. * * @author <a href="mailto:colin@marketcetera.com">Colin DuPlantis</a> * @version $Id: RpcServer.java 16901 2014-05-11 16:14:11Z colin $ * @since 2.4.0 */ @ClassVersion("$Id: RpcServer.java 16901 2014-05-11 16:14:11Z colin $") public class RpcServer<SessionClazz> implements Lifecycle,RpcServerServices<SessionClazz> { /* (non-Javadoc) * @see org.springframework.context.Lifecycle#isRunning() */ @Override public boolean isRunning() { return running.get(); } /* (non-Javadoc) * @see org.springframework.context.Lifecycle#start() */ @Override @PostConstruct public synchronized void start() { Validate.notNull(hostname); Validate.isTrue(port > 0 && port < 65536); Validate.notNull(sessionManager); Validate.notNull(authenticator); Validate.isTrue(threadPoolCore > 0); Validate.isTrue(threadPoolMax > 0); Validate.isTrue(threadPoolMax >= threadPoolCore); Validate.isTrue(sendBufferSize > 0); Validate.isTrue(receiveBufferSize > 0); Validate.notEmpty(serviceSpecs); Messages.SERVER_STARTING.info(this, hostname, port); if(isRunning()) { stop(); } try { reportContext = JAXBContext.newInstance(contextClassProvider==null?new Class<?>[0]:contextClassProvider.getContextClasses()); marshaller = reportContext.createMarshaller(); unmarshaller = reportContext.createUnmarshaller(); } catch (JAXBException e) { SLF4JLoggerProxy.error(this, e); throw new RuntimeException(e); } PeerInfo serverInfo = new PeerInfo(getRpcHostname(), getRpcPort()); executor = new ThreadPoolCallExecutor(threadPoolCore, threadPoolMax); DuplexTcpServerPipelineFactory serverFactory = new DuplexTcpServerPipelineFactory(serverInfo); serverFactory.setRpcServerCallExecutor(executor); ServerBootstrap bootstrap = new ServerBootstrap(); bootstrap.group(new NioEventLoopGroup(0, new RenamingThreadFactoryProxy("boss", Executors.defaultThreadFactory())), new NioEventLoopGroup(0, new RenamingThreadFactoryProxy("worker", Executors.defaultThreadFactory()))); bootstrap.channel(NioServerSocketChannel.class); bootstrap.childHandler(serverFactory); bootstrap.localAddress(serverInfo.getPort()); bootstrap.option(ChannelOption.SO_SNDBUF, sendBufferSize); bootstrap.option(ChannelOption.SO_RCVBUF, receiveBufferSize); bootstrap.childOption(ChannelOption.SO_RCVBUF, receiveBufferSize); bootstrap.childOption(ChannelOption.SO_SNDBUF, sendBufferSize); bootstrap.option(ChannelOption.TCP_NODELAY, noDelay); for(RpcServiceSpec<SessionClazz> serviceSpec : serviceSpecs) { serviceSpec.setRpcServerServices(this); BlockingService activeService = serviceSpec.generateService(); serverFactory.getRpcServiceRegistry().registerService(activeService); Messages.SERVICE_STARTING.info(this, serviceSpec.getDescription()); } channelToken = bootstrap.bind(); while(!channelToken.isDone()) { try { Thread.sleep(250); } catch (InterruptedException e) { throw new RuntimeException(e); } } // TODO throw exception? running.set(channelToken.isSuccess()); //RpcClientConnectionRegistry clientRegistry = new RpcClientConnectionRegistry(); //serverFactory.registerConnectionEventListener(clientRegistry); } /* (non-Javadoc) * @see org.springframework.context.Lifecycle#stop() */ @Override @PreDestroy public synchronized void stop() { Messages.SERVER_STOPPING.info(this); try { try { if(executor != null) { executor.shutdownNow(); } } catch (Exception ignored) {} try { if(channelToken != null && channelToken.channel() != null) { channelToken.channel().close(); } } catch (Exception ignored) {} for(SessionId session : rpcSessions.keySet()) { try { sessionManager.remove(session); } catch (Exception ignored) {} } } finally { rpcSessions.clear(); channelToken = null; executor = null; reportContext = null; marshaller = null; unmarshaller = null; running.set(false); } } /** * Get the serviceSpecs value. * * @return a <code>List<RpcServiceSpec></code> value */ public List<RpcServiceSpec<SessionClazz>> getServiceSpecs() { return serviceSpecs; } /** * Sets the serviceSpecs value. * * @param inServiceSpecs a <code>List<RpcServiceSpec></code> value */ public void setServiceSpecs(List<RpcServiceSpec<SessionClazz>> inServiceSpecs) { serviceSpecs.clear(); if(inServiceSpecs != null) { serviceSpecs.addAll(inServiceSpecs); } } /** * Get the sendBufferSize value. * * @return an <code>int</code> value */ public int getSendBufferSize() { return sendBufferSize; } /** * Sets the sendBufferSize value. * * @param inSendBufferSize an <code>int</code> value */ public void setSendBufferSize(int inSendBufferSize) { sendBufferSize = inSendBufferSize; } /** * Get the receiveBufferSize value. * * @return an <code>int</code> value */ public int getReceiveBufferSize() { return receiveBufferSize; } /** * Sets the receiveBufferSize value. * * @param inReceiveBufferSize an <code>int</code> value */ public void setReceiveBufferSize(int inReceiveBufferSize) { receiveBufferSize = inReceiveBufferSize; } /** * Get the noDelay value. * * @return a <code>boolean</code> value */ public boolean getNoDelay() { return noDelay; } /** * Sets the noDelay value. * * @param inNoDelay a <code>boolean</code> value */ public void setNoDelay(boolean inNoDelay) { noDelay = inNoDelay; } /** * Get the threadPoolCore value. * * @return an <code>int</code> value */ public int getThreadPoolCore() { return threadPoolCore; } /** * Sets the threadPoolCore value. * * @param inThreadPoolCore an <code>int</code> value */ public void setThreadPoolCore(int inThreadPoolCore) { threadPoolCore = inThreadPoolCore; } /** * Get the threadPoolMax value. * * @return an <code>int</code> value */ public int getThreadPoolMax() { return threadPoolMax; } /** * Sets the threadPoolMax value. * * @param inThreadPoolMax an <code>int</code> value */ public void setThreadPoolMax(int inThreadPoolMax) { threadPoolMax = inThreadPoolMax; } /** * Get the rpcHostname value. * * @return a <code>String</code> value */ public String getRpcHostname() { return hostname; } /** * Sets the rpcHostname value. * * @param inRpcHostname a <code>String</code> value */ public void setHostname(String inRpcHostname) { hostname = inRpcHostname; } /** * Get the rpcPort value. * * @return an <code>int</code> value */ public int getRpcPort() { return port; } /** * Sets the rpcPort value. * * @param inRpcPort an <code>int</code> value */ public void setPort(int inRpcPort) { port = inRpcPort; } /** * Get the authenticator value. * * @return an <code>Authenticator</code> value */ public Authenticator getAuthenticator() { return authenticator; } /** * Sets the authenticator value. * * @param inAuthenticator an <code>Authenticator</code> value */ public void setAuthenticator(Authenticator inAuthenticator) { authenticator = inAuthenticator; } /** * Get the sessionManager value. * * @return a <code>SessionManager<SessionClazz></code> value */ public SessionManager<SessionClazz> getSessionManager() { return sessionManager; } /** * Sets the sessionManager value. * * @param inSessionManager a <code>SessionManager<SessionClazz></code> value */ public void setSessionManager(SessionManager<SessionClazz> inSessionManager) { sessionManager = inSessionManager; } /** * Get the contextClassProvider value. * * @return a <code>ContextClassProvider</code> value */ public ContextClassProvider getContextClassProvider() { return contextClassProvider; } /** * Sets the contextClassProvider value. * * @param inContextClassProvider a <code>ContextClassProvider</code> value */ public void setContextClassProvider(ContextClassProvider inContextClassProvider) { contextClassProvider = inContextClassProvider; } /* (non-Javadoc) * @see org.marketcetera.client.rpc.RpcServerServices#login(org.marketcetera.client.rpc.Credentials) */ @Override public SessionId login(RpcCredentials inCredentials) { StatelessClientContext context = new StatelessClientContext(); context.setAppId(new AppId(inCredentials.getAppId())); context.setClientId(new NodeId(inCredentials.getClientId())); context.setVersionId(new VersionId(inCredentials.getVersionId())); LocaleWrapper locale = new LocaleWrapper(inCredentials.getLocale()); context.setLocale(locale); authenticator.shouldAllow(context, inCredentials.getUsername(), inCredentials.getPassword().toCharArray()); SessionId sessionId = SessionId.generate(); SessionHolder<SessionClazz> sessionHolder = new SessionHolder<SessionClazz>(inCredentials.getUsername(), context); sessionManager.put(sessionId, sessionHolder); rpcSessions.put(sessionId, inCredentials.getUsername()); return sessionId; } /* (non-Javadoc) * @see org.marketcetera.client.rpc.RpcServerServices#logout(java.lang.String) */ @Override public void logout(String inSessionIdValue) { SessionId session = new SessionId(inSessionIdValue); rpcSessions.remove(session); sessionManager.remove(session); } /* (non-Javadoc) * @see org.marketcetera.client.rpc.RpcServerServices#validateAndReturnSession(java.lang.String) */ @Override public SessionHolder<SessionClazz> validateAndReturnSession(String inSessionIdValue) { SessionId session = new SessionId(inSessionIdValue); SessionHolder<SessionClazz> sessionInfo = sessionManager.get(session); if(sessionInfo == null) { throw new IllegalArgumentException("Invalid session: " + inSessionIdValue); // TODO } return sessionInfo; } /* (non-Javadoc) * @see org.marketcetera.client.rpc.RpcServerServices#marshall(java.lang.Object) */ @Override public String marshal(Object inObject) throws JAXBException { StringWriter output = new StringWriter(); synchronized(marshaller) { marshaller.marshal(inObject, output); } return output.toString(); } /* (non-Javadoc) * @see org.marketcetera.client.rpc.RpcServerServices#unmarshall(java.lang.String) */ @Override @SuppressWarnings("unchecked") public <Clazz> Clazz unmarshall(String inData) throws JAXBException { synchronized(unmarshaller) { return (Clazz)unmarshaller.unmarshal(new StringReader(inData)); } } /** * manages sessions */ private SessionManager<SessionClazz> sessionManager; /** * provides authentication services */ private Authenticator authenticator; /** * indicates if the server is running */ private final AtomicBoolean running = new AtomicBoolean(false); /** * manages server calls */ private RpcServerCallExecutor executor; /** * channel handle */ private ChannelFuture channelToken; /** * authenticated RPC sessions */ private final Map<SessionId,String> rpcSessions = Maps.newConcurrentMap(); /** * hostname to bind */ private String hostname; /** * port to bind */ private int port; /** * send buffer size */ private int sendBufferSize = 1048576; /** * receive buffer size */ private int receiveBufferSize = 1048576; /** * indicates whether to employ Nagle's algorithm */ private boolean noDelay = true; /** * minimum size for the RCP server thread pool */ private int threadPoolCore = 10; /** * maximum size for the RCP server thread pool */ private int threadPoolMax = 200; /** * provides context classes for marshalling and unmarshalling */ private ContextClassProvider contextClassProvider; /** * context used to control the marshaller and unmarshaller */ private JAXBContext reportContext; /** * marshals data for JAXB */ private Marshaller marshaller; /** * unmarshals data for JAXB */ private Unmarshaller unmarshaller; /** * RPC services to manage */ private final List<RpcServiceSpec<SessionClazz>> serviceSpecs = Lists.newArrayList(); }