// =================================================================================================
// Copyright 2011 Twitter, Inc.
// -------------------------------------------------------------------------------------------------
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this work except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file, or at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =================================================================================================
package com.twitter.common.thrift;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.thrift.TProcessor;
import org.apache.thrift.TProcessorFactory;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.THsHaServer;
import org.apache.thrift.server.TNonblockingServer;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TThreadPoolServer;
import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.TTransportFactory;
import com.twitter.common.base.ExceptionalFunction;
import com.twitter.common.net.monitoring.TrafficMonitor;
import com.twitter.common.quantity.Amount;
import com.twitter.common.quantity.Time;
import com.twitter.common.stats.StatImpl;
import com.twitter.common.stats.Stats;
import com.twitter.common.thrift.monitoring.TMonitoredProcessor;
import com.twitter.common.thrift.monitoring.TMonitoredServerSocket;
import com.twitter.thrift.Status;
/**
 * Implementation of common functionality to satisfy the twitter ThriftService interface.
 *
 * <p>The server is built from a {@link ServerSetup} by a pluggable server-building function
 * (see {@link #THREADPOOL_SERVER} and {@link #NONBLOCKING_SERVER}). It is run on a dedicated
 * non-daemon thread that invokes the blocking {@link TServer#serve()} call. A coarse health
 * {@link Status} is tracked across start and shutdown.
 *
 * @author William Farner
 */
public abstract class ThriftServer {
  private static final Logger LOG = Logger.getLogger(ThriftServer.class.getName());

  /** Supplies a binary protocol factory with non-strict reads and strict writes. */
  public static final Supplier<TProtocolFactory> BINARY_PROTOCOL =
      new Supplier<TProtocolFactory>() {
        @Override public TProtocolFactory get() {
          return new TBinaryProtocol.Factory(false, true);
        }
      };

  /** Supplies a compact protocol factory. */
  public static final Supplier<TProtocolFactory> COMPACT_PROTOCOL =
      new Supplier<TProtocolFactory>() {
        @Override public TProtocolFactory get() {
          return new TCompactProtocol.Factory();
        }
      };

  /**
   * Builds a blocking {@link TThreadPoolServer}.
   *
   * <p>Uses the server socket already attached to the setup. If none is attached, it binds a new
   * one on {@link ServerSetup#getPort()} and attaches it to the setup. When the setup has a
   * traffic monitor, both the server socket and the processor are wrapped with the monitoring
   * decorators. A positive thread count fixes both the minimum and maximum worker pool size.
   * A {@code null} socket timeout is treated as 0.
   */
  public static final ExceptionalFunction<ServerSetup, TServer, TTransportException>
      THREADPOOL_SERVER = new ExceptionalFunction<ServerSetup, TServer, TTransportException>() {
        @Override public TServer apply(ServerSetup setup) throws TTransportException {
          TThreadPoolServer.Options options = new TThreadPoolServer.Options();
          if (setup.getNumThreads() > 0) {
            options.minWorkerThreads = setup.getNumThreads();
            options.maxWorkerThreads = setup.getNumThreads();
          }

          // If no socket supplied with the ServerSetup, initialize one based upon
          // supplied parameters.
          if (setup.getSocket() == null) {
            try {
              setup.setSocket(new ServerSocket(setup.getPort()));
            } catch (IOException e) {
              throw new TTransportException("Failed to create server socket on port " +
                  setup.getPort(), e);
            }
          }

          // ServerSetup permits a null socket timeout. Treat it as 0 (the value the convenience
          // constructors default to) instead of dereferencing null.
          int clientTimeoutMillis = setup.getSocketTimeout() == null
              ? 0
              : setup.getSocketTimeout().as(Time.MILLISECONDS);

          TProcessor processor = setup.getProcessor();
          TServerSocket socket;
          if (setup.isMonitored()) {
            TMonitoredServerSocket monitoredSocket = new TMonitoredServerSocket(
                setup.getSocket(), clientTimeoutMillis, setup.getMonitor());
            processor = new TMonitoredProcessor(processor, monitoredSocket, setup.getMonitor());
            socket = monitoredSocket;
          } else {
            socket = new TServerSocket(setup.getSocket(), clientTimeoutMillis);
          }

          TTransportFactory transportFactory = new TTransportFactory();
          return new TThreadPoolServer(processor, socket, transportFactory, transportFactory,
              setup.getProtoFactory(), setup.getProtoFactory(), options);
        }
      };

  /**
   * Builds a {@link THsHaServer}: a non-blocking server that hands requests to a worker thread
   * pool.
   *
   * <p>Always binds a fresh non-blocking socket on {@link ServerSetup#getPort()} and records the
   * underlying {@link ServerSocket} in the setup. Any socket already attached to the setup is
   * replaced. The setup's traffic monitor, if any, is not consulted by this builder.
   *
   * <p>The work queue defaults to the worker thread count when no positive queue size is given.
   * Active-thread and queue-size gauges are exported under the setup's name.
   */
  public static final ExceptionalFunction<ServerSetup, TServer, TTransportException>
      NONBLOCKING_SERVER = new ExceptionalFunction<ServerSetup, TServer, TTransportException>() {
        @Override public TServer apply(ServerSetup setup) throws TTransportException {
          TNonblockingServerSocket socket = setup.getSocketTimeout() == null
              ? new TNonblockingServerSocket(setup.getPort())
              : new TNonblockingServerSocket(setup.getPort(),
                  setup.getSocketTimeout().as(Time.MILLISECONDS));

          setup.setSocket(getServerSocketFor(socket));

          // just to grab defaults
          THsHaServer.Options options = new THsHaServer.Options();

          if (setup.getNumThreads() > 0) {
            options.workerThreads = setup.getNumThreads();
          }

          // default queue size to num threads: max response time becomes double avg service time
          final BlockingQueue<Runnable> queue =
              new ArrayBlockingQueue<Runnable>(setup.getQueueSize() > 0 ? setup.getQueueSize()
                  : options.workerThreads);
          final ThreadPoolExecutor invoker = new ThreadPoolExecutor(options.workerThreads,
              options.workerThreads, options.stopTimeoutVal, options.stopTimeoutUnit, queue);

          final String serverName = (setup.getName() != null ? setup.getName() : "no-name");
          Stats.export(new StatImpl<Integer>(serverName + "_thrift_server_active_threads") {
            @Override public Integer read() { return invoker.getActiveCount(); }
          });

          Stats.export(new StatImpl<Integer>(serverName + "_thrift_server_queue_size") {
            @Override public Integer read() { return queue.size(); }
          });

          return new THsHaServer(new TProcessorFactory(setup.getProcessor()), socket,
              new TFramedTransport.Factory(),
              setup.getProtoFactory(), setup.getProtoFactory(), invoker,
              new TNonblockingServer.Options());
        }
      };

  /**
   * Extracts the {@link ServerSocket} wrapped by a {@link TNonblockingServerSocket}.
   *
   * <p>Thrift doesn't provide access to the socket it creates, so this reads the private
   * {@code serverSocket_} field reflectively. This is the only way to know which ephemeral port
   * was bound. TODO: Patch thrift to provide access so we don't have to do this.
   *
   * @param thriftSocket The thrift socket to inspect.
   * @return The wrapped server socket.
   * @throws TTransportException If the field is missing or cannot be read.
   */
  @VisibleForTesting
  static ServerSocket getServerSocketFor(TNonblockingServerSocket thriftSocket)
      throws TTransportException {
    try {
      Field field = TNonblockingServerSocket.class.getDeclaredField("serverSocket_");
      field.setAccessible(true);
      return (ServerSocket) field.get(thriftSocket);
    } catch (NoSuchFieldException e) {
      throw new TTransportException("Couldn't get listening port", e);
    } catch (SecurityException e) {
      throw new TTransportException("Couldn't get listening port", e);
    } catch (IllegalAccessException e) {
      throw new TTransportException("Couldn't get listening port", e);
    }
  }

  private final String name;
  private final String version;

  // Setup used for the most recent start; null until the first start.
  private ServerSetup serverSetup = null;
  // Server built for the most recent start; cleared by shutdown().
  private TServer server = null;

  // The thread that is responsible for invoking the blocking TServer#serve() call.
  private Thread listeningThread;

  // Current health status of the server. Volatile because the listening thread may set it
  // (to DEAD) while other threads read it.
  private volatile Status status = Status.STARTING;

  // Time at which the server went live. Should only be used for relative (duration) tracking.
  // -1 until the server has been started.
  private long serverStartNanos = -1;

  private final Supplier<TProtocolFactory> protoFactorySupplier;

  private final ExceptionalFunction<ServerSetup, TServer, TTransportException> serverSupplier;

  /**
   * Creates a new default thrift server, which uses a TThreadPoolServer and the binary protocol.
   *
   * @param name Name for the server.
   * @param version Version identifier.
   */
  public ThriftServer(String name, String version) {
    this(name, version, BINARY_PROTOCOL, THREADPOOL_SERVER);
  }

  /**
   * Creates a new thrift server with the provided configuration.
   *
   * @param name Name for the server.
   * @param version Version identifier.
   * @param protoFactorySupplier Supplier to build the protocol factory to use.
   * @param serverSupplier Function to build a TServer object based on the server setup.
   * @throws NullPointerException If any argument is null.
   */
  public ThriftServer(String name, String version, Supplier<TProtocolFactory> protoFactorySupplier,
      ExceptionalFunction<ServerSetup, TServer, TTransportException> serverSupplier) {
    this.name = Preconditions.checkNotNull(name);
    this.version = Preconditions.checkNotNull(version);
    this.protoFactorySupplier = Preconditions.checkNotNull(protoFactorySupplier);
    this.serverSupplier = Preconditions.checkNotNull(serverSupplier);
  }

  /**
   * Starts the server.
   *
   * <p>This may be called at any point except when the server is already alive. That is, it's
   * allowable to start, stop, and re-start the server.
   *
   * @param port The port to listen on (0 for an ephemeral port).
   * @param processor The processor to handle requests.
   */
  public void start(int port, TProcessor processor) {
    start(new ServerSetup(name, port, processor, protoFactorySupplier.get()));
  }

  /**
   * Starts the server.
   *
   * <p>This may be called at any point except when the server is already alive. That is, it's
   * allowable to start, stop, and re-start the server. If the socket cannot be opened, the failure
   * is logged and the status becomes DEAD. If any other runtime failure occurs during startup,
   * the status also becomes DEAD and the exception is rethrown.
   *
   * @param setup options for server
   * @throws IllegalStateException If the server is already alive.
   */
  public void start(ServerSetup setup) {
    Preconditions.checkNotNull(setup.getProcessor());
    Preconditions.checkState(status != Status.ALIVE, "Server must only be started once.");
    setStatus(Status.ALIVE);

    try {
      doStart(setup);
    } catch (TTransportException e) {
      LOG.log(Level.SEVERE, "Failed to open thrift socket.", e);
      setStatus(Status.DEAD);
    } catch (RuntimeException e) {
      // Without this, a failed start would leave the status ALIVE and the checkState above
      // would reject every subsequent restart attempt.
      setStatus(Status.DEAD);
      throw e;
    }
  }

  /**
   * Builds the server for {@code setup} and launches the listening thread.
   *
   * @param setup options for server
   * @throws TTransportException If the server builder fails to open its transport.
   */
  @VisibleForTesting
  protected void doStart(ServerSetup setup) throws TTransportException {
    serverSetup = setup;
    // Keep the server in a local so the listening thread never reads the field after a
    // concurrent shutdown() has nulled it out, which would surface as a spurious NPE and DEAD.
    final TServer newServer = serverSupplier.apply(setup);
    server = newServer;
    serverStartNanos = System.nanoTime();

    LOG.info("Starting thrift server on port " + getListeningPort());

    listeningThread = new ThreadFactoryBuilder().setDaemon(false).build().newThread(new Runnable() {
      @Override public void run() {
        try {
          newServer.serve();
        } catch (Throwable t) {
          LOG.log(Level.WARNING,
              "Uncaught exception while attempting to handle service requests: " + t, t);
          setStatus(Status.DEAD);
        }
      }
    });

    listeningThread.start();
  }

  /**
   * Returns the local port of the socket the server is listening on.
   *
   * @return The bound local port.
   * @throws IllegalStateException If the server has not been started, is not ALIVE, or has no
   *     socket recorded in its setup.
   */
  public int getListeningPort() {
    Preconditions.checkState(serverSetup != null);
    Preconditions.checkState(status == Status.ALIVE);
    Preconditions.checkState(serverSetup.getSocket() != null);
    return serverSetup.getSocket().getLocalPort();
  }

  public String getName() {
    return name;
  }

  public String getVersion() {
    return version;
  }

  public Status getStatus() {
    return status;
  }

  /**
   * Changes the status of the server.
   *
   * @param status New status.
   */
  protected void setStatus(Status status) {
    LOG.info("Moving from status " + this.status + " to " + status);
    this.status = status;
  }

  /**
   * Returns the number of whole seconds since the server was last started.
   * The value is not meaningful before the first start.
   */
  public long uptime() {
    return TimeUnit.SECONDS.convert(System.nanoTime() - serverStartNanos, TimeUnit.NANOSECONDS);
  }

  /**
   * Notification to the server that a shutdown request has been made, and the server is no longer
   * processing requests. The implementer may veto the shutdown by throwing an exception. A veto
   * would suggest a failure to terminate backend connections in a timely manner.
   *
   * @throws Exception If the shutdown request could not be honored.
   */
  protected void tryShutdown() throws Exception {
    // Default no-op.
  }

  /**
   * Attempts to shut down the server.
   *
   * <p>The server may be shut down at any time, though the request will be ignored if the server
   * is already stopped. If {@link #tryShutdown()} vetoes the request, the status becomes WARNING
   * instead of STOPPED.
   */
  public void shutdown() {
    if (status == Status.STOPPED) {
      LOG.info("Server already stopped, shutdown request ignored.");
      return;
    }

    LOG.info("Received shutdown request, stopping server.");
    setStatus(Status.STOPPING);

    if (server != null) server.stop();
    server = null;

    // TODO(William Farner): Figure out what happens to queued / in-process requests when the server
    // is stopped. Might want to allow a sleep period for the active requests to be completed.

    try {
      tryShutdown();
    } catch (Exception e) {
      LOG.log(Level.WARNING, "Service handler vetoed shutdown request.", e);
      setStatus(Status.WARNING);
      return;
    }

    setStatus(Status.STOPPED);
  }

  /**
   * Attempts to shut down this server, and waits for the listening thread to exit.
   *
   * @param timeout Maximum amount of time to wait for shutdown before giving up. A timeout of
   *     zero means wait forever.
   *
   * @throws InterruptedException If interrupted while waiting for shutdown.
   */
  public void awaitShutdown(Amount<Long, Time> timeout) throws InterruptedException {
    Preconditions.checkNotNull(timeout);
    shutdown();

    if (listeningThread != null) {
      listeningThread.join(timeout.as(Time.MILLISECONDS));
    }
  }

  /**
   * Represents the server configuration variables needed to construct a TServer.
   *
   * <p>For {@code numThreads} and {@code queueSize}, a negative value (conventionally -1) means
   * "use the server builder's default". Zero is rejected. A {@code port} of 0 requests an
   * ephemeral port.
   */
  public static final class ServerSetup {
    private final String name;
    private final int port;
    private final TProcessor processor;
    private final TProtocolFactory protoFactory;
    private final int numThreads;
    private final int queueSize;
    private final TrafficMonitor<InetSocketAddress> monitor;
    // Bound server socket; attached by the server builder if not supplied up front.
    private ServerSocket socket = null;

    /**
     * Timeout for client sockets from accept; may be null.
     */
    private final Amount<Integer, Time> socketTimeout;

    public ServerSetup(int port, TProcessor processor, TProtocolFactory protoFactory) {
      this(port, processor, protoFactory, -1, Amount.of(0, Time.MILLISECONDS));
    }

    public ServerSetup(String name, int port, TProcessor processor, TProtocolFactory protoFactory) {
      this(name, port, processor, protoFactory, -1, Amount.of(0, Time.MILLISECONDS));
    }

    public ServerSetup(int port, TProcessor processor, TProtocolFactory protoFactory,
        TrafficMonitor<InetSocketAddress> monitor) {
      this(null, port, processor, protoFactory, -1, Amount.of(0, Time.MILLISECONDS), monitor);
    }

    public ServerSetup(int port, TProcessor processor, TProtocolFactory protoFactory,
        int numThreads, Amount<Integer, Time> socketTimeout) {
      this(null, port, processor, protoFactory, numThreads, socketTimeout, null);
    }

    public ServerSetup(String name, int port, TProcessor processor, TProtocolFactory protoFactory,
        int numThreads, Amount<Integer, Time> socketTimeout) {
      this(name, port, processor, protoFactory, numThreads, socketTimeout, null);
    }

    public ServerSetup(String name, int port, TProcessor processor, TProtocolFactory protoFactory,
        int numThreads, int queueSize, Amount<Integer, Time> socketTimeout) {
      this(name, port, processor, protoFactory, numThreads, queueSize, socketTimeout, null);
    }

    public ServerSetup(String name, int port, TProcessor processor, TProtocolFactory protoFactory,
        int numThreads, Amount<Integer, Time> socketTimeout,
        TrafficMonitor<InetSocketAddress> monitor) {
      this(name, port, processor, protoFactory, numThreads, -1, socketTimeout, monitor);
    }

    /**
     * Creates a fully specified server setup.
     *
     * @param name Server name used for exported stats; may be null.
     * @param port Port to bind, in [0, 65535]; 0 requests an ephemeral port.
     * @param processor Request processor.
     * @param protoFactory Protocol factory for both input and output.
     * @param numThreads Worker thread count, or negative for the default; must not be 0.
     * @param queueSize Work queue size, or negative for the default; must not be 0.
     * @param socketTimeout Client socket timeout; may be null, otherwise must be non-negative.
     * @param monitor Traffic monitor, or null for an unmonitored server.
     * @throws IllegalArgumentException If any numeric argument is out of range.
     */
    public ServerSetup(String name, int port, TProcessor processor, TProtocolFactory protoFactory,
        int numThreads, int queueSize, Amount<Integer, Time> socketTimeout,
        TrafficMonitor<InetSocketAddress> monitor) {
      // 0xFFFF (65535) is itself a valid TCP port, so the upper bound is inclusive.
      Preconditions.checkArgument(port >= 0 && port <= 0xFFFF, "Invalid port: " + port);
      Preconditions.checkArgument(numThreads != 0);
      Preconditions.checkArgument(queueSize != 0);
      if (socketTimeout != null) Preconditions.checkArgument(socketTimeout.getValue() >= 0);

      this.name = name;
      this.port = port;
      this.processor = processor;
      this.protoFactory = protoFactory;
      this.numThreads = numThreads;
      this.queueSize = queueSize;
      this.socketTimeout = socketTimeout;
      this.monitor = monitor;
    }

    public String getName() {
      return name;
    }

    public int getPort() {
      return port;
    }

    public int getNumThreads() {
      return numThreads;
    }

    public int getQueueSize() {
      return queueSize;
    }

    /** Returns the client socket timeout, which may be null. */
    public Amount<Integer, Time> getSocketTimeout() {
      return socketTimeout;
    }

    public TProcessor getProcessor() {
      return processor;
    }

    public TProtocolFactory getProtoFactory() {
      return protoFactory;
    }

    public ServerSocket getSocket() {
      return socket;
    }

    public void setSocket(ServerSocket socket) {
      this.socket = socket;
    }

    /** Returns true when a traffic monitor was supplied. */
    public boolean isMonitored() {
      return monitor != null;
    }

    public TrafficMonitor<InetSocketAddress> getMonitor() {
      return monitor;
    }
  }
}