package org.marketcetera.util.rpc;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import java.io.StringReader;
import java.io.StringWriter;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import org.apache.commons.lang.Validate;
import org.marketcetera.util.log.SLF4JLoggerProxy;
import org.marketcetera.util.misc.ClassVersion;
import org.marketcetera.util.ws.ContextClassProvider;
import org.marketcetera.util.ws.stateful.Authenticator;
import org.marketcetera.util.ws.stateful.SessionHolder;
import org.marketcetera.util.ws.stateful.SessionManager;
import org.marketcetera.util.ws.stateless.StatelessClientContext;
import org.marketcetera.util.ws.tags.AppId;
import org.marketcetera.util.ws.tags.NodeId;
import org.marketcetera.util.ws.tags.SessionId;
import org.marketcetera.util.ws.tags.VersionId;
import org.marketcetera.util.ws.wrappers.LocaleWrapper;
import org.springframework.context.Lifecycle;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.protobuf.BlockingService;
import com.googlecode.protobuf.pro.duplex.PeerInfo;
import com.googlecode.protobuf.pro.duplex.execute.RpcServerCallExecutor;
import com.googlecode.protobuf.pro.duplex.execute.ThreadPoolCallExecutor;
import com.googlecode.protobuf.pro.duplex.server.DuplexTcpServerPipelineFactory;
import com.googlecode.protobuf.pro.duplex.util.RenamingThreadFactoryProxy;
/* $License$ */
/**
* Provides RPC services.
*
* @author <a href="mailto:colin@marketcetera.com">Colin DuPlantis</a>
* @version $Id: RpcServer.java 16901 2014-05-11 16:14:11Z colin $
* @since 2.4.0
*/
@ClassVersion("$Id: RpcServer.java 16901 2014-05-11 16:14:11Z colin $")
public class RpcServer<SessionClazz>
implements Lifecycle,RpcServerServices<SessionClazz>
{
/* (non-Javadoc)
* @see org.springframework.context.Lifecycle#isRunning()
*/
@Override
public boolean isRunning()
{
return running.get();
}
/* (non-Javadoc)
* @see org.springframework.context.Lifecycle#start()
*/
@Override
@PostConstruct
public synchronized void start()
{
Validate.notNull(hostname);
Validate.isTrue(port > 0 && port < 65536);
Validate.notNull(sessionManager);
Validate.notNull(authenticator);
Validate.isTrue(threadPoolCore > 0);
Validate.isTrue(threadPoolMax > 0);
Validate.isTrue(threadPoolMax >= threadPoolCore);
Validate.isTrue(sendBufferSize > 0);
Validate.isTrue(receiveBufferSize > 0);
Validate.notEmpty(serviceSpecs);
Messages.SERVER_STARTING.info(this,
hostname,
port);
if(isRunning()) {
stop();
}
try {
reportContext = JAXBContext.newInstance(contextClassProvider==null?new Class<?>[0]:contextClassProvider.getContextClasses());
marshaller = reportContext.createMarshaller();
unmarshaller = reportContext.createUnmarshaller();
} catch (JAXBException e) {
SLF4JLoggerProxy.error(this,
e);
throw new RuntimeException(e);
}
PeerInfo serverInfo = new PeerInfo(getRpcHostname(),
getRpcPort());
executor = new ThreadPoolCallExecutor(threadPoolCore,
threadPoolMax);
DuplexTcpServerPipelineFactory serverFactory = new DuplexTcpServerPipelineFactory(serverInfo);
serverFactory.setRpcServerCallExecutor(executor);
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(new NioEventLoopGroup(0,
new RenamingThreadFactoryProxy("boss",
Executors.defaultThreadFactory())),
new NioEventLoopGroup(0,
new RenamingThreadFactoryProxy("worker",
Executors.defaultThreadFactory())));
bootstrap.channel(NioServerSocketChannel.class);
bootstrap.childHandler(serverFactory);
bootstrap.localAddress(serverInfo.getPort());
bootstrap.option(ChannelOption.SO_SNDBUF,
sendBufferSize);
bootstrap.option(ChannelOption.SO_RCVBUF,
receiveBufferSize);
bootstrap.childOption(ChannelOption.SO_RCVBUF,
receiveBufferSize);
bootstrap.childOption(ChannelOption.SO_SNDBUF,
sendBufferSize);
bootstrap.option(ChannelOption.TCP_NODELAY,
noDelay);
for(RpcServiceSpec<SessionClazz> serviceSpec : serviceSpecs) {
serviceSpec.setRpcServerServices(this);
BlockingService activeService = serviceSpec.generateService();
serverFactory.getRpcServiceRegistry().registerService(activeService);
Messages.SERVICE_STARTING.info(this,
serviceSpec.getDescription());
}
channelToken = bootstrap.bind();
while(!channelToken.isDone()) {
try {
Thread.sleep(250);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
// TODO throw exception?
running.set(channelToken.isSuccess());
//RpcClientConnectionRegistry clientRegistry = new RpcClientConnectionRegistry();
//serverFactory.registerConnectionEventListener(clientRegistry);
}
/* (non-Javadoc)
* @see org.springframework.context.Lifecycle#stop()
*/
@Override
@PreDestroy
public synchronized void stop()
{
Messages.SERVER_STOPPING.info(this);
try {
try {
if(executor != null) {
executor.shutdownNow();
}
} catch (Exception ignored) {}
try {
if(channelToken != null && channelToken.channel() != null) {
channelToken.channel().close();
}
} catch (Exception ignored) {}
for(SessionId session : rpcSessions.keySet()) {
try {
sessionManager.remove(session);
} catch (Exception ignored) {}
}
} finally {
rpcSessions.clear();
channelToken = null;
executor = null;
reportContext = null;
marshaller = null;
unmarshaller = null;
running.set(false);
}
}
/**
* Get the serviceSpecs value.
*
* @return a <code>List<RpcServiceSpec></code> value
*/
public List<RpcServiceSpec<SessionClazz>> getServiceSpecs()
{
return serviceSpecs;
}
/**
* Sets the serviceSpecs value.
*
* @param inServiceSpecs a <code>List<RpcServiceSpec></code> value
*/
public void setServiceSpecs(List<RpcServiceSpec<SessionClazz>> inServiceSpecs)
{
serviceSpecs.clear();
if(inServiceSpecs != null) {
serviceSpecs.addAll(inServiceSpecs);
}
}
/**
* Get the sendBufferSize value.
*
* @return an <code>int</code> value
*/
public int getSendBufferSize()
{
return sendBufferSize;
}
/**
* Sets the sendBufferSize value.
*
* @param inSendBufferSize an <code>int</code> value
*/
public void setSendBufferSize(int inSendBufferSize)
{
sendBufferSize = inSendBufferSize;
}
/**
* Get the receiveBufferSize value.
*
* @return an <code>int</code> value
*/
public int getReceiveBufferSize()
{
return receiveBufferSize;
}
/**
* Sets the receiveBufferSize value.
*
* @param inReceiveBufferSize an <code>int</code> value
*/
public void setReceiveBufferSize(int inReceiveBufferSize)
{
receiveBufferSize = inReceiveBufferSize;
}
/**
* Get the noDelay value.
*
* @return a <code>boolean</code> value
*/
public boolean getNoDelay()
{
return noDelay;
}
/**
* Sets the noDelay value.
*
* @param inNoDelay a <code>boolean</code> value
*/
public void setNoDelay(boolean inNoDelay)
{
noDelay = inNoDelay;
}
/**
* Get the threadPoolCore value.
*
* @return an <code>int</code> value
*/
public int getThreadPoolCore()
{
return threadPoolCore;
}
/**
* Sets the threadPoolCore value.
*
* @param inThreadPoolCore an <code>int</code> value
*/
public void setThreadPoolCore(int inThreadPoolCore)
{
threadPoolCore = inThreadPoolCore;
}
/**
* Get the threadPoolMax value.
*
* @return an <code>int</code> value
*/
public int getThreadPoolMax()
{
return threadPoolMax;
}
/**
* Sets the threadPoolMax value.
*
* @param inThreadPoolMax an <code>int</code> value
*/
public void setThreadPoolMax(int inThreadPoolMax)
{
threadPoolMax = inThreadPoolMax;
}
/**
* Get the rpcHostname value.
*
* @return a <code>String</code> value
*/
public String getRpcHostname()
{
return hostname;
}
/**
* Sets the rpcHostname value.
*
* @param inRpcHostname a <code>String</code> value
*/
public void setHostname(String inRpcHostname)
{
hostname = inRpcHostname;
}
/**
* Get the rpcPort value.
*
* @return an <code>int</code> value
*/
public int getRpcPort()
{
return port;
}
/**
* Sets the rpcPort value.
*
* @param inRpcPort an <code>int</code> value
*/
public void setPort(int inRpcPort)
{
port = inRpcPort;
}
/**
* Get the authenticator value.
*
* @return an <code>Authenticator</code> value
*/
public Authenticator getAuthenticator()
{
return authenticator;
}
/**
* Sets the authenticator value.
*
* @param inAuthenticator an <code>Authenticator</code> value
*/
public void setAuthenticator(Authenticator inAuthenticator)
{
authenticator = inAuthenticator;
}
/**
* Get the sessionManager value.
*
* @return a <code>SessionManager<SessionClazz></code> value
*/
public SessionManager<SessionClazz> getSessionManager()
{
return sessionManager;
}
/**
* Sets the sessionManager value.
*
* @param inSessionManager a <code>SessionManager<SessionClazz></code> value
*/
public void setSessionManager(SessionManager<SessionClazz> inSessionManager)
{
sessionManager = inSessionManager;
}
/**
* Get the contextClassProvider value.
*
* @return a <code>ContextClassProvider</code> value
*/
public ContextClassProvider getContextClassProvider()
{
return contextClassProvider;
}
/**
* Sets the contextClassProvider value.
*
* @param inContextClassProvider a <code>ContextClassProvider</code> value
*/
public void setContextClassProvider(ContextClassProvider inContextClassProvider)
{
contextClassProvider = inContextClassProvider;
}
/* (non-Javadoc)
* @see org.marketcetera.client.rpc.RpcServerServices#login(org.marketcetera.client.rpc.Credentials)
*/
@Override
public SessionId login(RpcCredentials inCredentials)
{
StatelessClientContext context = new StatelessClientContext();
context.setAppId(new AppId(inCredentials.getAppId()));
context.setClientId(new NodeId(inCredentials.getClientId()));
context.setVersionId(new VersionId(inCredentials.getVersionId()));
LocaleWrapper locale = new LocaleWrapper(inCredentials.getLocale());
context.setLocale(locale);
authenticator.shouldAllow(context,
inCredentials.getUsername(),
inCredentials.getPassword().toCharArray());
SessionId sessionId = SessionId.generate();
SessionHolder<SessionClazz> sessionHolder = new SessionHolder<SessionClazz>(inCredentials.getUsername(),
context);
sessionManager.put(sessionId,
sessionHolder);
rpcSessions.put(sessionId,
inCredentials.getUsername());
return sessionId;
}
/* (non-Javadoc)
* @see org.marketcetera.client.rpc.RpcServerServices#logout(java.lang.String)
*/
@Override
public void logout(String inSessionIdValue)
{
SessionId session = new SessionId(inSessionIdValue);
rpcSessions.remove(session);
sessionManager.remove(session);
}
/* (non-Javadoc)
* @see org.marketcetera.client.rpc.RpcServerServices#validateAndReturnSession(java.lang.String)
*/
@Override
public SessionHolder<SessionClazz> validateAndReturnSession(String inSessionIdValue)
{
SessionId session = new SessionId(inSessionIdValue);
SessionHolder<SessionClazz> sessionInfo = sessionManager.get(session);
if(sessionInfo == null) {
throw new IllegalArgumentException("Invalid session: " + inSessionIdValue); // TODO
}
return sessionInfo;
}
/* (non-Javadoc)
* @see org.marketcetera.client.rpc.RpcServerServices#marshall(java.lang.Object)
*/
@Override
public String marshal(Object inObject)
throws JAXBException
{
StringWriter output = new StringWriter();
synchronized(marshaller) {
marshaller.marshal(inObject,
output);
}
return output.toString();
}
/* (non-Javadoc)
* @see org.marketcetera.client.rpc.RpcServerServices#unmarshall(java.lang.String)
*/
@Override
@SuppressWarnings("unchecked")
public <Clazz> Clazz unmarshall(String inData)
throws JAXBException
{
synchronized(unmarshaller) {
return (Clazz)unmarshaller.unmarshal(new StringReader(inData));
}
}
/**
* manages sessions
*/
private SessionManager<SessionClazz> sessionManager;
/**
* provides authentication services
*/
private Authenticator authenticator;
/**
* indicates if the server is running
*/
private final AtomicBoolean running = new AtomicBoolean(false);
/**
* manages server calls
*/
private RpcServerCallExecutor executor;
/**
* channel handle
*/
private ChannelFuture channelToken;
/**
* authenticated RPC sessions
*/
private final Map<SessionId,String> rpcSessions = Maps.newConcurrentMap();
/**
* hostname to bind
*/
private String hostname;
/**
* port to bind
*/
private int port;
/**
* send buffer size
*/
private int sendBufferSize = 1048576;
/**
* receive buffer size
*/
private int receiveBufferSize = 1048576;
/**
* indicates whether to employ Nagle's algorithm
*/
private boolean noDelay = true;
/**
* minimum size for the RCP server thread pool
*/
private int threadPoolCore = 10;
/**
* maximum size for the RCP server thread pool
*/
private int threadPoolMax = 200;
/**
* provides context classes for marshalling and unmarshalling
*/
private ContextClassProvider contextClassProvider;
/**
* context used to control the marshaller and unmarshaller
*/
private JAXBContext reportContext;
/**
* marshals data for JAXB
*/
private Marshaller marshaller;
/**
* unmarshals data for JAXB
*/
private Unmarshaller unmarshaller;
/**
* RPC services to manage
*/
private final List<RpcServiceSpec<SessionClazz>> serviceSpecs = Lists.newArrayList();
}