// ================================================================================================= // Copyright 2011 Twitter, Inc. // ------------------------------------------------------------------------------------------------- // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this work except in compliance with the License. // You may obtain a copy of the License in the LICENSE file, or at: // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ================================================================================================= package com.twitter.common.thrift; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.twitter.common.base.Closure; import com.twitter.common.base.Closures; import com.twitter.common.base.MorePreconditions; import com.twitter.common.net.pool.Connection; import com.twitter.common.net.pool.ConnectionFactory; import com.twitter.common.quantity.Amount; import com.twitter.common.quantity.Time; import org.apache.thrift.transport.TFramedTransport; import org.apache.thrift.transport.TNonblockingSocket; import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; import java.io.IOException; import java.net.InetSocketAddress; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; /** * A connection factory for thrift transport connections to a given host. This connection factory * is lazy and will only create a configured maximum number of active connections - where a * {@link ConnectionFactory#create(com.twitter.common.quantity.Amount) created} connection that has * not been {@link #destroy destroyed} is considered active. * * @author John Sirois */ public class ThriftConnectionFactory implements ConnectionFactory<Connection<TTransport, InetSocketAddress>> { public enum TransportType { BLOCKING, FRAMED, NONBLOCKING; /** * Async clients implicitly use a framed transport, requiring the server they connect to to do * the same. This prevents specifying a nonblocking client without a framed transport, since * that is not compatible with thrift and would simply cause the client to blow up when making a * request. Instead, you must explicitly say useFramedTransport(true) for any buildAsync(). */ public static TransportType get(boolean framedTransport, boolean nonblocking) { if (nonblocking) { Preconditions.checkArgument(framedTransport, "nonblocking client requires a server running framed transport"); return NONBLOCKING; } return framedTransport ? FRAMED : BLOCKING; } } private static InetSocketAddress asEndpoint(String host, int port) { MorePreconditions.checkNotBlank(host); Preconditions.checkArgument(port > 0); return InetSocketAddress.createUnresolved(host, port); } private InetSocketAddress endpoint; private final int maxConnections; private final TransportType transportType; private final Amount<Long, Time> socketTimeout; private final Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback; private boolean sslTransport = false; private final Set<Connection<TTransport, InetSocketAddress>> activeConnections = Sets.newSetFromMap( Maps.<Connection<TTransport, InetSocketAddress>, Boolean>newIdentityHashMap()); private volatile int lastActiveConnectionsSize = 0; private final Lock activeConnectionsWriteLock = new ReentrantLock(true); /** * Creates a thrift connection factory with a plain socket (non-framed transport). * This is the same as calling {@link #ThriftConnectionFactory(String, int, int, boolean)} with * {@code framedTransport} set to {@code false}. * * @param host Host to connect to. * @param port Port to connect on. * @param maxConnections Maximum number of connections for this host:port. */ public ThriftConnectionFactory(String host, int port, int maxConnections) { this(host, port, maxConnections, TransportType.BLOCKING); } /** * Creates a thrift connection factory. * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, * otherwise a raw {@link TSocket} will be used. * * @param host Host to connect to. * @param port Port to connect on. * @param maxConnections Maximum number of connections for this host:port. * @param framedTransport Whether to use framed or blocking transport. */ public ThriftConnectionFactory(String host, int port, int maxConnections, boolean framedTransport) { this(asEndpoint(host, port), maxConnections, TransportType.get(framedTransport, false)); } /** * Creates a thrift connection factory. * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, * otherwise a raw {@link TSocket} will be used. * * @param endpoint Endpoint to connect to. * @param maxConnections Maximum number of connections for this host:port. * @param framedTransport Whether to use framed or blocking transport. */ public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, boolean framedTransport) { this(endpoint, maxConnections, TransportType.get(framedTransport, false)); } /** * Creates a thrift connection factory. * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, * otherwise a raw {@link TSocket} will be used. * If {@code nonblocking} is set to {@code true}, {@link TNonblockingSocket} will be used, * otherwise a raw {@link TSocket} will be used. * Timeouts are ignored when nonblocking transport is used. * * @param host Host to connect to. * @param port Port to connect on. * @param maxConnections Maximum number of connections for this host:port. * @param transportType Whether to use normal blocking, framed blocking, or non-blocking * (implicitly framed) transport. */ public ThriftConnectionFactory(String host, int port, int maxConnections, TransportType transportType) { this(host, port, maxConnections, transportType, null); } /** * Creates a thrift connection factory. * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, * otherwise a raw {@link TSocket} will be used. * If {@code nonblocking} is set to {@code true}, {@link TNonblockingSocket} will be used, * otherwise a raw {@link TSocket} will be used. * Timeouts are ignored when nonblocking transport is used. * * @param host Host to connect to. * @param port Port to connect on. * @param maxConnections Maximum number of connections for this host:port. * @param transportType Whether to use normal blocking, framed blocking, or non-blocking * (implicitly framed) transport. * @param socketTimeout timeout on thrift i/o operations, or null to default to connectTimeout o * the blocking client. */ public ThriftConnectionFactory(String host, int port, int maxConnections, TransportType transportType, Amount<Long, Time> socketTimeout) { this(asEndpoint(host, port), maxConnections, transportType, socketTimeout); } public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, TransportType transportType) { this(endpoint, maxConnections, transportType, null); } /** * Creates a thrift connection factory. * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, * otherwise a raw {@link TSocket} will be used. * If {@code nonblocking} is set to {@code true}, {@link TNonblockingSocket} will be used, * otherwise a raw {@link TSocket} will be used. * Timeouts are ignored when nonblocking transport is used. * * @param endpoint Endpoint to connect to. * @param maxConnections Maximum number of connections for this host:port. * @param transportType Whether to use normal blocking, framed blocking, or non-blocking * (implicitly framed) transport. * @param socketTimeout timeout on thrift i/o operations, or null to default to connectTimeout o * the blocking client. */ public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, TransportType transportType, Amount<Long, Time> socketTimeout) { this(endpoint, maxConnections, transportType, socketTimeout, Closures.<Connection<TTransport, InetSocketAddress>>noop(), false); } public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, TransportType transportType, Amount<Long, Time> socketTimeout, Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback, boolean sslTransport) { Preconditions.checkArgument(maxConnections > 0, "maxConnections must be at least 1"); if (socketTimeout != null) { Preconditions.checkArgument(socketTimeout.as(Time.MILLISECONDS) >= 0); } this.endpoint = Preconditions.checkNotNull(endpoint); this.maxConnections = maxConnections; this.transportType = transportType; this.socketTimeout = socketTimeout; this.postCreateCallback = Preconditions.checkNotNull(postCreateCallback); this.sslTransport = sslTransport; } @Override public boolean mightCreate() { return lastActiveConnectionsSize < maxConnections; } /** * FIXME: shouldn't this throw TimeoutException instead of returning null * in the timeout cases as per the ConnectionFactory.create javadoc? */ @Override public Connection<TTransport, InetSocketAddress> create(Amount<Long, Time> timeout) throws TTransportException, IOException { Preconditions.checkNotNull(timeout); if (timeout.getValue() == 0) { return create(); } try { long timeRemainingNs = timeout.as(Time.NANOSECONDS); long start = System.nanoTime(); if(activeConnectionsWriteLock.tryLock(timeRemainingNs, TimeUnit.NANOSECONDS)) { try { if (!willCreateSafe()) { return null; } timeRemainingNs -= (System.nanoTime() - start); return createConnection((int) TimeUnit.NANOSECONDS.toMillis(timeRemainingNs)); } finally { activeConnectionsWriteLock.unlock(); } } else { return null; } } catch (InterruptedException e) { return null; } } private Connection<TTransport, InetSocketAddress> create() throws TTransportException, IOException { activeConnectionsWriteLock.lock(); try { if (!willCreateSafe()) { return null; } return createConnection(0); } finally { activeConnectionsWriteLock.unlock(); } } private Connection<TTransport, InetSocketAddress> createConnection(int timeoutMillis) throws TTransportException, IOException { TTransport transport = createTransport(timeoutMillis); if (transport == null) { return null; } Connection<TTransport, InetSocketAddress> connection = new TTransportConnection(transport, endpoint); postCreateCallback.execute(connection); activeConnections.add(connection); lastActiveConnectionsSize = activeConnections.size(); return connection; } private boolean willCreateSafe() { return activeConnections.size() < maxConnections; } @VisibleForTesting TTransport createTransport(int timeoutMillis) throws TTransportException, IOException { TSocket socket = null; if (transportType != TransportType.NONBLOCKING) { // can't do a nonblocking create on a blocking transport if (timeoutMillis <= 0) { return null; } if (sslTransport) { SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); SSLSocket ssl_socket = (SSLSocket) factory.createSocket(endpoint.getHostName(), endpoint.getPort()); ssl_socket.setSoTimeout(timeoutMillis); return new TSocket(ssl_socket); } else { socket = new TSocket(endpoint.getHostName(), endpoint.getPort(), timeoutMillis); } } try { switch (transportType) { case BLOCKING: socket.open(); setSocketTimeout(socket); return socket; case FRAMED: TFramedTransport transport = new TFramedTransport(socket); transport.open(); setSocketTimeout(socket); return transport; case NONBLOCKING: try { return new TNonblockingSocket(endpoint.getHostName(), endpoint.getPort()); } catch (IOException e) { throw new IOException("Failed to create non-blocking transport to " + endpoint, e); } } } catch (TTransportException e) { throw new TTransportException("Failed to create transport to " + endpoint, e); } throw new IllegalArgumentException("unknown transport type " + transportType); } private void setSocketTimeout(TSocket socket) { if (socketTimeout != null) { socket.setTimeout(socketTimeout.as(Time.MILLISECONDS).intValue()); } } @Override public void destroy(Connection<TTransport, InetSocketAddress> connection) { activeConnectionsWriteLock.lock(); try { boolean wasActiveConnection = activeConnections.remove(connection); Preconditions.checkArgument(wasActiveConnection, "connection %s not created by this factory", connection); lastActiveConnectionsSize = activeConnections.size(); } finally { activeConnectionsWriteLock.unlock(); } // We close the connection outside the critical section which means we may have more connections // "active" (open) than maxConnections for a very short time connection.close(); } @Override public String toString() { return String.format("%s[%s]", getClass().getSimpleName(), endpoint); } }