// =================================================================================================
// Copyright 2011 Twitter, Inc.
// -------------------------------------------------------------------------------------------------
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this work except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file, or at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =================================================================================================
package com.twitter.common.thrift;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.thrift.async.TAsyncClient;
import org.apache.thrift.async.TAsyncClientManager;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransport;
import com.twitter.common.base.Closure;
import com.twitter.common.base.Closures;
import com.twitter.common.base.MorePreconditions;
import com.twitter.common.net.loadbalancing.LeastConnectedStrategy;
import com.twitter.common.net.loadbalancing.LoadBalancer;
import com.twitter.common.net.loadbalancing.LoadBalancerImpl;
import com.twitter.common.net.loadbalancing.LoadBalancingStrategy;
import com.twitter.common.net.loadbalancing.MarkDeadStrategyWithHostCheck;
import com.twitter.common.net.loadbalancing.TrafficMonitorAdapter;
import com.twitter.common.net.monitoring.TrafficMonitor;
import com.twitter.common.net.pool.Connection;
import com.twitter.common.net.pool.ConnectionPool;
import com.twitter.common.net.pool.DynamicHostSet;
import com.twitter.common.net.pool.DynamicPool;
import com.twitter.common.net.pool.MetaPool;
import com.twitter.common.net.pool.ObjectPool;
import com.twitter.common.quantity.Amount;
import com.twitter.common.quantity.Time;
import com.twitter.common.stats.Stats;
import com.twitter.common.stats.StatsProvider;
import com.twitter.common.thrift.ThriftConnectionFactory.TransportType;
import com.twitter.common.util.BackoffDecider;
import com.twitter.common.util.BackoffStrategy;
import com.twitter.common.util.TruncatedBinaryBackoff;
import com.twitter.common.util.concurrent.ForwardingExecutorService;
import com.twitter.thrift.ServiceInstance;
/**
 * A utility that provides convenience methods to build common {@link Thrift}s.
 *
 * <p>The thrift factory allows you to specify parameters that define how the client connects to
 * and communicates with servers, such as the transport type, connection settings, and load
 * balancing. Request-level settings like sync/async and retries should be set on the
 * {@link Thrift} instance that this factory will create.
 *
 * <p>The factory will attempt to provide reasonable defaults to allow the caller to minimize the
 * amount of necessary configuration. Currently, the default behavior includes:
 *
 * <ul>
 *   <li> A test lease/release for each host will be performed every second
 *       {@link #withDeadConnectionRestoreInterval(Amount)}
 *   <li> At most 50 connections will be established to each host
 *       {@link #withMaxConnectionsPerEndpoint(int)}
 *   <li> Unframed transport {@link #useFramedTransport(boolean)}
 *   <li> A load balancing strategy that will mark hosts dead and prefer least-connected hosts.
 *       Hosts are marked dead if the most recent connection attempt was a failure or else based on
 *       the windowed error rate of attempted RPCs. If the error rate for a connected host exceeds
 *       20% over the last second, the host will be disabled for 2 seconds ascending up to 10
 *       seconds if the elevated error rate persists. A fresh default strategy is created for each
 *       client that is built.
 *       {@link #withLoadBalancingStrategy(LoadBalancingStrategy)}
 *   <li> Statistics are reported through {@link Stats}
 *       {@link #withStatsProvider(StatsProvider)}
 *   <li> A service name matching the simple name of the class that encloses the thrift service
 *       interface {@link #withServiceName(String)}
 *   <li> SSL disabled {@link #withSslEnabled()}
 * </ul>
 *
 * <p>This class performs no synchronization; configure and build from a single thread.
 *
 * @author John Sirois
 */
public class ThriftFactory<T> {
  private static final Amount<Long, Time> DEFAULT_DEAD_TARGET_RESTORE_INTERVAL =
      Amount.of(1L, Time.SECONDS);

  private static final int DEFAULT_MAX_CONNECTIONS_PER_ENDPOINT = 50;

  private final Class<T> serviceInterface;
  // Null means "derive the client factory reflectively from the service interface".
  private Function<TTransport, T> clientFactory;
  private int maxConnectionsPerEndpoint;
  private Amount<Long, Time> connectionRestoreInterval;
  private boolean framedTransport;
  // Null means "create a fresh default strategy for each client built".
  private LoadBalancingStrategy<InetSocketAddress> loadBalancingStrategy = null;
  private final TrafficMonitor<InetSocketAddress> monitor;
  // Null means "not set"; see withSocketTimeout for the fallback.
  private Amount<Long, Time> socketTimeout = null;
  private Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback = Closures.noop();
  private StatsProvider statsProvider = Stats.STATS_PROVIDER;
  private String serviceName;
  private boolean sslTransport;

  /**
   * Creates a new factory for clients of the given thrift service interface, populated with the
   * defaults described in the class documentation.
   *
   * @param serviceInterface The interface of the thrift service to make a client for.
   * @param <T> The service interface type.
   * @return A new factory.
   */
  public static <T> ThriftFactory<T> create(Class<T> serviceInterface) {
    return new ThriftFactory<T>(serviceInterface);
  }

  /**
   * Creates a default factory that will use unframed blocking transport.
   *
   * @param serviceInterface The interface of the thrift service to make a client for.
   */
  private ThriftFactory(Class<T> serviceInterface) {
    this.serviceInterface = Thrift.checkServiceInterface(serviceInterface);
    this.maxConnectionsPerEndpoint = DEFAULT_MAX_CONNECTIONS_PER_ENDPOINT;
    this.connectionRestoreInterval = DEFAULT_DEAD_TARGET_RESTORE_INTERVAL;
    this.framedTransport = false;
    this.monitor = new TrafficMonitor<InetSocketAddress>(serviceInterface.getName());
    this.serviceName = serviceInterface.getEnclosingClass().getSimpleName();
    this.sslTransport = false;
  }

  /**
   * Validates the configuration shared by all build methods.
   *
   * @throws IllegalArgumentException if the configuration is invalid.
   */
  private void checkBaseState() {
    Preconditions.checkArgument(maxConnectionsPerEndpoint > 0,
        "Must allow at least 1 connection per endpoint; %s specified", maxConnectionsPerEndpoint);
  }

  /**
   * Returns the traffic monitor that this factory wires into the load balancer of every client it
   * builds. The same monitor instance is shared by all of those clients.
   *
   * @return The factory's traffic monitor.
   */
  public TrafficMonitor<InetSocketAddress> getMonitor() {
    return monitor;
  }

  /**
   * Creates a synchronous thrift client over a fixed set of backends, and initializes its
   * connection pools.
   *
   * @param backends Backends to connect to; must not be empty.
   * @return A new thrift client.
   */
  public Thrift<T> build(Set<InetSocketAddress> backends) {
    checkBaseState();
    MorePreconditions.checkNotBlank(backends);

    ManagedThreadPool managedThreadPool = createManagedThreadpool(backends.size());
    LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer();
    Function<TTransport, T> clientFactory = getClientFactory();

    ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool =
        createConnectionPool(backends, loadBalancer, managedThreadPool, false);

    return new Thrift<T>(managedThreadPool, connectionPool, loadBalancer, serviceName,
        serviceInterface, clientFactory, false, sslTransport);
  }

  /**
   * Creates a synchronous thrift client that will communicate with a dynamic host set.
   *
   * @param hostSet The host set to use as a backend.
   * @return A thrift client.
   * @throws ThriftFactoryException If an error occurred while creating the client.
   */
  public Thrift<T> build(DynamicHostSet<ServiceInstance> hostSet) throws ThriftFactoryException {
    checkBaseState();
    Preconditions.checkNotNull(hostSet);

    // The host set's size is unknown up front; the pool resizes as backends are chosen.
    ManagedThreadPool managedThreadPool = createManagedThreadpool(1);
    LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer();
    Function<TTransport, T> clientFactory = getClientFactory();

    ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool =
        createConnectionPool(hostSet, loadBalancer, managedThreadPool, false);

    return new Thrift<T>(managedThreadPool, connectionPool, loadBalancer, serviceName,
        serviceInterface, clientFactory, false, sslTransport);
  }

  private ManagedThreadPool createManagedThreadpool(int initialEndpointCount) {
    return new ManagedThreadPool(serviceName, initialEndpointCount, maxConnectionsPerEndpoint);
  }

  /**
   * A finite thread pool that monitors backend choice events to dynamically resize. This
   * {@link java.util.concurrent.ExecutorService} implementation immediately rejects requests when
   * there are no more available worker threads (requests are not queued).
   *
   * <p>The pool is sized to {@code maxConnectionsPerEndpoint} threads for each backend most
   * recently chosen by the load balancer, and never below one backend's worth.
   */
  private static class ManagedThreadPool extends ForwardingExecutorService<ThreadPoolExecutor>
      implements Closure<Collection<InetSocketAddress>> {

    private static final Logger LOG = Logger.getLogger(ManagedThreadPool.class.getName());

    private static ThreadPoolExecutor createThreadPool(String serviceName, int initialSize) {
      ThreadFactory threadFactory =
          new ThreadFactoryBuilder()
              .setNameFormat("Thrift[" + serviceName + "][%d]")
              .setDaemon(true)
              .build();
      // A SynchronousQueue has no capacity, so work is handed directly to a thread or rejected.
      return new ThreadPoolExecutor(initialSize, initialSize, 0, TimeUnit.MILLISECONDS,
          new SynchronousQueue<Runnable>(), threadFactory);
    }

    private final int maxConnectionsPerEndpoint;

    public ManagedThreadPool(String serviceName, int initialEndpointCount,
        int maxConnectionsPerEndpoint) {
      super(createThreadPool(serviceName, initialEndpointCount * maxConnectionsPerEndpoint));
      this.maxConnectionsPerEndpoint = maxConnectionsPerEndpoint;
      setRejectedExecutionHandler(initialEndpointCount);
    }

    /**
     * Installs a rejection handler whose message reflects the current pool capacity.
     *
     * @param endpointCount Number of endpoints the pool is currently sized for.
     */
    private void setRejectedExecutionHandler(int endpointCount) {
      final String message =
          String.format("All %d x %d connections in use", endpointCount, maxConnectionsPerEndpoint);
      delegate.setRejectedExecutionHandler(new RejectedExecutionHandler() {
        @Override public void rejectedExecution(Runnable runnable, ThreadPoolExecutor executor) {
          throw new RejectedExecutionException(message);
        }
      });
    }

    /**
     * Resizes the pool to match the backends the load balancer just chose.
     *
     * @param chosenBackends The backends currently available for selection.
     */
    @Override
    public void execute(Collection<InetSocketAddress> chosenBackends) {
      int previousPoolSize = delegate.getMaximumPoolSize();
      /*
       * In the case of no available backends, we need to make sure we pass in a positive pool
       * size to our delegate. In particular, java.util.concurrent.ThreadPoolExecutor does not
       * accept zero as a valid core or max pool size.
       */
      int backendCount = Math.max(chosenBackends.size(), 1);
      int newPoolSize = backendCount * maxConnectionsPerEndpoint;

      if (previousPoolSize != newPoolSize) {
        LOG.info(String.format("Re-sizing deadline thread pool from: %d to: %d",
            previousPoolSize, newPoolSize));
        // ThreadPoolExecutor requires core <= max at every step, so grow the max before the core
        // and shrink the core before the max.
        if (previousPoolSize < newPoolSize) {
          delegate.setMaximumPoolSize(newPoolSize);
          delegate.setCorePoolSize(newPoolSize);
        } else {
          delegate.setCorePoolSize(newPoolSize);
          delegate.setMaximumPoolSize(newPoolSize);
        }
        setRejectedExecutionHandler(backendCount);
      }
    }
  }

  /**
   * Creates an asynchronous thrift client that will communicate with a fixed set of backends.
   *
   * @param backends Backends to connect to; must not be empty.
   * @return A thrift client.
   * @throws ThriftFactoryException If an error occurred while creating the client.
   */
  public Thrift<T> buildAsync(Set<InetSocketAddress> backends) throws ThriftFactoryException {
    checkBaseState();
    MorePreconditions.checkNotBlank(backends);

    LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer();
    // Async clients don't use a managed thread pool, so backend choices need no reaction.
    Closure<Collection<InetSocketAddress>> noop = Closures.noop();
    Function<TTransport, T> asyncClientFactory = getAsyncClientFactory();

    ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool =
        createConnectionPool(backends, loadBalancer, noop, true);

    return new Thrift<T>(connectionPool, loadBalancer,
        serviceName, serviceInterface, asyncClientFactory, true);
  }

  /**
   * Creates an asynchronous thrift client that will communicate with a dynamic host set.
   *
   * @param hostSet The host set to use as a backend.
   * @return A thrift client.
   * @throws ThriftFactoryException If an error occurred while creating the client.
   */
  public Thrift<T> buildAsync(DynamicHostSet<ServiceInstance> hostSet)
      throws ThriftFactoryException {
    checkBaseState();
    Preconditions.checkNotNull(hostSet);

    LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer();
    // Async clients don't use a managed thread pool, so backend choices need no reaction.
    Closure<Collection<InetSocketAddress>> noop = Closures.noop();
    Function<TTransport, T> asyncClientFactory = getAsyncClientFactory();

    ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool =
        createConnectionPool(hostSet, loadBalancer, noop, true);

    return new Thrift<T>(connectionPool, loadBalancer,
        serviceName, serviceInterface, asyncClientFactory, true);
  }

  /**
   * Prepare the client factory, which will create client class instances from transports.
   *
   * @return The client factory to use.
   */
  private Function<TTransport, T> getClientFactory() {
    return clientFactory == null ? createClientFactory(serviceInterface) : clientFactory;
  }

  /**
   * Prepare the async client factory, which will create client class instances from transports.
   *
   * @return The client factory to use.
   * @throws ThriftFactoryException If there was a problem creating the factory.
   */
  private Function<TTransport, T> getAsyncClientFactory() throws ThriftFactoryException {
    try {
      return clientFactory == null ? createAsyncClientFactory(serviceInterface) : clientFactory;
    } catch (IOException e) {
      throw new ThriftFactoryException("Failed to create async client factory.", e);
    }
  }

  /**
   * Creates a meta pool over one connection pool per fixed backend.
   */
  private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool(
      Set<InetSocketAddress> backends, LoadBalancer<InetSocketAddress> loadBalancer,
      Closure<Collection<InetSocketAddress>> onBackendsChosen, boolean nonblocking) {

    ImmutableMap.Builder<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>>
        backendBuilder = ImmutableMap.builder();
    for (InetSocketAddress backend : backends) {
      backendBuilder.put(backend, createConnectionPool(backend, nonblocking));
    }

    return new MetaPool<TTransport, InetSocketAddress>(backendBuilder.build(),
        loadBalancer, onBackendsChosen, connectionRestoreInterval);
  }

  /**
   * Creates a dynamic pool that creates a per-endpoint connection pool for each host in the set.
   * Per-endpoint pools are created on demand, so they read this factory's configuration at that
   * time.
   */
  private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool(
      DynamicHostSet<ServiceInstance> hostSet, LoadBalancer<InetSocketAddress> loadBalancer,
      Closure<Collection<InetSocketAddress>> onBackendsChosen,
      final boolean nonblocking) throws ThriftFactoryException {

    Function<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>>
        endpointPoolFactory =
      new Function<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>>() {
        @Override public ObjectPool<Connection<TTransport, InetSocketAddress>> apply(
            InetSocketAddress endpoint) {
          return createConnectionPool(endpoint, nonblocking);
        }
      };

    try {
      return new DynamicPool<ServiceInstance, TTransport, InetSocketAddress>(hostSet,
          endpointPoolFactory, loadBalancer, onBackendsChosen, connectionRestoreInterval,
          Util.GET_ADDRESS, Util.IS_ALIVE);
    } catch (DynamicHostSet.MonitorException e) {
      throw new ThriftFactoryException("Failed to monitor host set.", e);
    }
  }

  /**
   * Creates the connection pool for a single endpoint using the current transport settings.
   */
  private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool(
      InetSocketAddress backend, boolean nonblocking) {

    ThriftConnectionFactory connectionFactory = new ThriftConnectionFactory(
        backend, maxConnectionsPerEndpoint, TransportType.get(framedTransport, nonblocking),
        socketTimeout, postCreateCallback, sslTransport);

    return new ConnectionPool<Connection<TTransport, InetSocketAddress>>(connectionFactory,
        statsProvider);
  }

  /**
   * Overrides the reflectively-derived client factory. Intended for tests.
   *
   * @param clientFactory Function creating a client from a transport.
   * @return A reference to the factory.
   */
  @VisibleForTesting
  public ThriftFactory<T> withClientFactory(Function<TTransport, T> clientFactory) {
    this.clientFactory = Preconditions.checkNotNull(clientFactory);
    return this;
  }

  /**
   * Instructs the factory to create SSL transports for new connections.
   *
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withSslEnabled() {
    this.sslTransport = true;
    return this;
  }

  /**
   * Specifies the maximum number of connections that should be made to any single endpoint.
   *
   * @param maxConnectionsPerEndpoint Maximum number of connections per endpoint; must be positive.
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withMaxConnectionsPerEndpoint(int maxConnectionsPerEndpoint) {
    Preconditions.checkArgument(maxConnectionsPerEndpoint > 0,
        "Must allow at least 1 connection per endpoint; %s specified", maxConnectionsPerEndpoint);
    this.maxConnectionsPerEndpoint = maxConnectionsPerEndpoint;
    return this;
  }

  /**
   * Specifies the interval at which dead endpoint connections should be checked and revived.
   *
   * @param connectionRestoreInterval the time interval to check; must not be negative.
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withDeadConnectionRestoreInterval(
      Amount<Long, Time> connectionRestoreInterval) {
    Preconditions.checkNotNull(connectionRestoreInterval);
    Preconditions.checkArgument(connectionRestoreInterval.getValue() >= 0,
        "A negative interval is invalid: %s", connectionRestoreInterval);
    this.connectionRestoreInterval = connectionRestoreInterval;
    return this;
  }

  /**
   * Instructs the factory whether framed transport should be used.
   *
   * @param framedTransport Whether to use framed transport.
   * @return A reference to the factory.
   */
  public ThriftFactory<T> useFramedTransport(boolean framedTransport) {
    this.framedTransport = framedTransport;
    return this;
  }

  /**
   * Specifies the load balancer to use when interacting with multiple backends. Note that the
   * supplied strategy instance is shared by every client subsequently built by this factory.
   *
   * @param strategy Load balancing strategy.
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withLoadBalancingStrategy(
      LoadBalancingStrategy<InetSocketAddress> strategy) {
    this.loadBalancingStrategy = Preconditions.checkNotNull(strategy);
    return this;
  }

  /**
   * Creates a load balancer around the configured strategy (or a fresh default one), reporting
   * traffic to this factory's monitor.
   */
  private LoadBalancer<InetSocketAddress> createLoadBalancer() {
    // Don't store the default in the field: the default strategy keeps per-host backoff state,
    // and caching it would make every client built by this factory share that state.
    LoadBalancingStrategy<InetSocketAddress> strategy = (loadBalancingStrategy != null)
        ? loadBalancingStrategy
        : createDefaultLoadBalancingStrategy();
    return LoadBalancerImpl.create(TrafficMonitorAdapter.create(strategy, monitor));
  }

  /**
   * Builds the default strategy: least-connected selection, with hosts marked dead by a per-host
   * {@link BackoffDecider} that tolerates a 20% failure rate over a 1 second window and backs off
   * from 2 up to 10 seconds.
   */
  private LoadBalancingStrategy<InetSocketAddress> createDefaultLoadBalancingStrategy() {
    Function<InetSocketAddress, BackoffDecider> backoffFactory =
        new Function<InetSocketAddress, BackoffDecider>() {
          @Override public BackoffDecider apply(InetSocketAddress socket) {
            BackoffStrategy backoffStrategy = new TruncatedBinaryBackoff(
                Amount.of(2L, Time.SECONDS), Amount.of(10L, Time.SECONDS));

            return BackoffDecider.builder(socket.toString())
                .withTolerateFailureRate(0.2)
                .withRequestWindow(Amount.of(1L, Time.SECONDS))
                .withSeedSize(5)
                .withStrategy(backoffStrategy)
                .withRecoveryType(BackoffDecider.RecoveryType.FULL_CAPACITY)
                .withStatsProvider(statsProvider)
                .build();
          }
        };

    return new MarkDeadStrategyWithHostCheck<InetSocketAddress>(
        new LeastConnectedStrategy<InetSocketAddress>(), backoffFactory);
  }

  /**
   * Specifies the net read/write timeout to set via SO_TIMEOUT on the thrift blocking client
   * or AsyncClient.setTimeout on the thrift async client. Defaults to the connectTimeout on
   * the blocking client if not set.
   *
   * @param socketTimeout timeout on thrift i/o operations; must not be negative.
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withSocketTimeout(Amount<Long, Time> socketTimeout) {
    // Validate before assigning so a rejected value never leaks into the factory's state.
    Preconditions.checkNotNull(socketTimeout);
    Preconditions.checkArgument(socketTimeout.as(Time.MILLISECONDS) >= 0,
        "A negative socket timeout is invalid: %s", socketTimeout);
    this.socketTimeout = socketTimeout;
    return this;
  }

  /**
   * Specifies the callback to notify when a connection has been created. The callback may
   * be used to make thrift calls to the connection, but must not invalidate it.
   * Defaults to a no-op closure.
   *
   * @param postCreateCallback function to setup new connections
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withPostCreateCallback(
      Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback) {
    this.postCreateCallback = Preconditions.checkNotNull(postCreateCallback);
    return this;
  }

  /**
   * Registers a custom stats provider to use to track various client stats.
   *
   * @param statsProvider the {@code StatsProvider} to use
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withStatsProvider(StatsProvider statsProvider) {
    this.statsProvider = Preconditions.checkNotNull(statsProvider);
    return this;
  }

  /**
   * Name to be passed to Thrift constructor, used in stats.
   *
   * @param serviceName string to use; must not be blank.
   * @return A reference to the factory.
   */
  public ThriftFactory<T> withServiceName(String serviceName) {
    this.serviceName = MorePreconditions.checkNotBlank(serviceName);
    return this;
  }

  /**
   * Creates a factory that instantiates the blocking client class enclosed alongside
   * {@code serviceInterface}, wrapping each transport in a {@link TBinaryProtocol}.
   */
  private static <T> Function<TTransport, T> createClientFactory(Class<T> serviceInterface) {
    final Constructor<? extends T> implementationConstructor =
        findImplementationConstructor(serviceInterface);

    return new Function<TTransport, T>() {
      @Override public T apply(TTransport transport) {
        return instantiate(implementationConstructor, new TBinaryProtocol(transport));
      }
    };
  }

  /**
   * Creates a factory that instantiates the async client class enclosed alongside
   * {@code serviceInterface}. All clients from the returned factory share one
   * {@link TAsyncClientManager}, which is stopped by a JVM shutdown hook.
   *
   * <p>The socket timeout is captured now, so later calls to {@link #withSocketTimeout} do not
   * affect clients produced by the returned factory.
   *
   * @throws IOException If the {@link TAsyncClientManager} could not be created.
   */
  private Function<TTransport, T> createAsyncClientFactory(
      final Class<T> serviceInterface) throws IOException {

    // Resolve the constructor first so a lookup failure doesn't leak a started client manager.
    final Constructor<? extends T> implementationConstructor =
        findAsyncImplementationConstructor(serviceInterface);
    final Amount<Long, Time> timeout = socketTimeout;

    final TAsyncClientManager clientManager = new TAsyncClientManager();
    Runtime.getRuntime().addShutdownHook(new Thread() {
      @Override public void run() {
        clientManager.stop();
      }
    });

    return new Function<TTransport, T>() {
      @Override public T apply(TTransport transport) {
        Preconditions.checkNotNull(transport);
        Preconditions.checkArgument(transport instanceof TNonblockingTransport,
            "Invalid transport provided to client factory: " + transport.getClass());

        T client = instantiate(implementationConstructor, new TBinaryProtocol.Factory(),
            clientManager, transport);
        if (timeout != null) {
          ((TAsyncClient) client).setTimeout(timeout.as(Time.MILLISECONDS));
        }
        return client;
      }
    };
  }

  /**
   * Reflectively invokes {@code constructor}, wrapping any reflection failure in an unchecked
   * exception so it can be thrown from a {@link Function}.
   *
   * @param constructor The client constructor to invoke.
   * @param args The constructor arguments.
   * @param <C> The constructed type.
   * @return The new instance.
   * @throws RuntimeException wrapping any instantiation, access or invocation failure.
   */
  private static <C> C instantiate(Constructor<? extends C> constructor, Object... args) {
    try {
      return constructor.newInstance(args);
    } catch (InstantiationException e) {
      throw new RuntimeException(e);
    } catch (IllegalAccessException e) {
      throw new RuntimeException(e);
    } catch (InvocationTargetException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Finds the {@code (TProtocol)} constructor of the blocking client implementation.
   *
   * @throws IllegalArgumentException If no such implementation or constructor exists.
   */
  private static <T> Constructor<? extends T> findImplementationConstructor(
      final Class<T> serviceInterface) {
    Class<? extends T> implementationClass = findImplementationClass(serviceInterface);
    try {
      return implementationClass.getConstructor(TProtocol.class);
    } catch (NoSuchMethodException e) {
      throw new IllegalArgumentException("Failed to find a single argument TProtocol constructor "
          + "in service client class: " + implementationClass, e);
    }
  }

  /**
   * Finds the {@code (TProtocolFactory, TAsyncClientManager, TNonblockingTransport)} constructor
   * of the async client implementation.
   *
   * @throws IllegalArgumentException If no such implementation or constructor exists.
   */
  private static <T> Constructor<? extends T> findAsyncImplementationConstructor(
      final Class<T> serviceInterface) {
    Class<? extends T> implementationClass = findImplementationClass(serviceInterface);
    try {
      return implementationClass.getConstructor(TProtocolFactory.class, TAsyncClientManager.class,
          TNonblockingTransport.class);
    } catch (NoSuchMethodException e) {
      throw new IllegalArgumentException("Failed to find expected constructor "
          + "in service client class: " + implementationClass, e);
    }
  }

  /**
   * Finds the first public class nested in the same enclosing class as {@code serviceInterface}
   * that implements it.
   *
   * @throws IllegalArgumentException If the interface is not nested or has no such sibling.
   */
  private static <T> Class<? extends T> findImplementationClass(final Class<T> serviceInterface) {
    Class<?> enclosingClass = serviceInterface.getEnclosingClass();
    if (enclosingClass == null) {
      throw new IllegalArgumentException("Could not find a sibling enclosed implementation of "
          + "service interface: " + serviceInterface);
    }

    Class<?> implementation;
    try {
      implementation = Iterables.find(ImmutableList.copyOf(enclosingClass.getClasses()),
          new Predicate<Class<?>>() {
            @Override public boolean apply(Class<?> inner) {
              return !serviceInterface.equals(inner)
                  && serviceInterface.isAssignableFrom(inner);
            }
          });
    } catch (NoSuchElementException e) {
      throw new IllegalArgumentException("Could not find a sibling enclosed implementation of "
          + "service interface: " + serviceInterface, e);
    }

    // Safe: the predicate only accepts classes assignable to serviceInterface.
    @SuppressWarnings("unchecked")
    Class<? extends T> implementationClass = (Class<? extends T>) implementation;
    return implementationClass;
  }

  /**
   * Signals a failure to construct a thrift client.
   */
  public static class ThriftFactoryException extends Exception {
    public ThriftFactoryException(String msg) {
      super(msg);
    }

    public ThriftFactoryException(String msg, Throwable t) {
      super(msg, t);
    }
  }
}