// =================================================================================================
// Copyright 2011 Twitter, Inc.
// -------------------------------------------------------------------------------------------------
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this work except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file, or at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =================================================================================================
package com.twitter.common.thrift.callers;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.twitter.common.net.pool.Connection;
import com.twitter.common.net.pool.ObjectPool;
import com.twitter.common.quantity.Amount;
import com.twitter.common.quantity.Time;
import com.twitter.common.net.pool.ResourceExhaustedException;
import com.twitter.common.thrift.TResourceExhaustedException;
import com.twitter.common.thrift.TTimeoutException;
import com.twitter.common.net.loadbalancing.RequestTracker;
import org.apache.thrift.async.AsyncMethodCallback;
import org.apache.thrift.transport.TTransport;
import javax.annotation.Nullable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.TimeoutException;
import java.util.logging.Logger;
/**
* A caller that issues calls to a target that is assumed to be a client to a thrift service.
*
* @author William Farner
*/
public class ThriftCaller<T> implements Caller {
  private static final Logger LOG = Logger.getLogger(ThriftCaller.class.getName());

  private final ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool;
  private final RequestTracker<InetSocketAddress> requestTracker;
  private final Function<TTransport, T> clientFactory;
  private final Amount<Long, Time> timeout;
  private final boolean debug;

  /**
   * Creates a new thrift caller.
   *
   * @param connectionPool The connection pool to use.
   * @param requestTracker The request tracker to notify of request results.
   * @param clientFactory Factory to use for building client object instances.
   * @param timeout The timeout to use when requesting objects from the connection pool.  A value
   *     that is not positive causes connections to be requested without a timeout argument.
   * @param debug Whether to use the caller in debug mode.  In debug mode every failed call is
   *     logged at warning level.
   */
  public ThriftCaller(ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool,
      RequestTracker<InetSocketAddress> requestTracker, Function<TTransport, T> clientFactory,
      Amount<Long, Time> timeout, boolean debug) {
    this.connectionPool = connectionPool;
    this.requestTracker = requestTracker;
    this.clientFactory = clientFactory;
    this.timeout = timeout;
    this.debug = debug;
  }

  /**
   * Obtains a connection from the pool, builds a client on top of its transport and reflectively
   * invokes {@code method} on that client.
   *
   * <p>On success the request tracker is notified of a {@code SUCCESS} and the connection is
   * released to the pool; on failure the tracker is notified of a {@code FAILED} result and the
   * connection is removed from the pool.  This also applies when the client could not be built,
   * so that a connection that was handed out is always handed back.
   *
   * @param method The client method to invoke.
   * @param args The arguments to invoke the method with, may be {@code null} for a method that
   *     takes no arguments.
   * @param callback The callback of an asynchronous call, or {@code null} for a synchronous call.
   * @param connectTimeoutOverride When non-null, used instead of the configured timeout when
   *     requesting a connection from the pool.
   * @return The value returned by the invoked method, or {@code null} when an asynchronous call
   *     failed and the failure was delivered to the callback.
   * @throws Throwable The failure of the invoked method for synchronous calls, or any failure to
   *     obtain a connection, build the client or perform the reflective invocation.
   */
  @Override
  public Object call(Method method, Object[] args, @Nullable AsyncMethodCallback callback,
      @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable {

    final Connection<TTransport, InetSocketAddress> connection =
        getConnection(connectTimeoutOverride);
    final long startNanos = System.nanoTime();

    ResultCapture capture = new ResultCapture() {
      @Override public void success() {
        try {
          requestTracker.requestResult(connection.getEndpoint(),
              RequestTracker.RequestResult.SUCCESS, System.nanoTime() - startNanos);
        } finally {
          // The connection goes back to the pool even if the tracker throws.
          connectionPool.release(connection);
        }
      }

      @Override public boolean fail(Throwable t) {
        if (debug) {
          LOG.warning(String.format("Call to endpoint: %s failed: %s", connection, t));
        }

        try {
          requestTracker.requestResult(connection.getEndpoint(),
              RequestTracker.RequestResult.FAILED, System.nanoTime() - startNanos);
        } finally {
          // A connection that saw a failure is not reused.
          connectionPool.remove(connection);
        }
        return true;
      }
    };

    // The connection is already checked out at this point, so a failure to build the client must
    // hand it back to the pool instead of leaking it.
    final T client;
    try {
      client = clientFactory.apply(connection.get());
    } catch (RuntimeException e) {
      capture.fail(e);
      throw e;
    }

    return invokeMethod(client, method, args, callback, capture);
  }

  /**
   * Reflectively invokes {@code method} on {@code target} and makes sure the outcome reaches
   * {@code capture}.
   *
   * <p>For an asynchronous call (non-null {@code callback}) the callback is wrapped in a
   * {@code WrappedMethodCallback} built with {@code capture} and appended as the last argument of
   * the invocation; the outcome is then expected to be reported through that wrapper (it is
   * defined outside this file).  For a synchronous call the outcome is reported here.
   *
   * @param target The client object to invoke the method on.
   * @param method The method to invoke.
   * @param args The method arguments, excluding the callback; may be {@code null}.
   * @param callback The caller's callback for asynchronous calls, {@code null} otherwise.
   * @param capture Receives the result of the call.
   * @return The result of the invocation, or {@code null} when an asynchronous invocation threw
   *     and the failure was passed to the callback.
   * @throws Throwable The cause of a failed synchronous invocation, or the reflective failure if
   *     the method could not be invoked at all.
   */
  private static Object invokeMethod(Object target, Method method, Object[] args,
      AsyncMethodCallback callback, final ResultCapture capture) throws Throwable {

    // Swap the wrapped callback out for ours.
    if (callback != null) {
      callback = new WrappedMethodCallback(callback, capture);

      // Reflection proxies pass null rather than an empty array for zero-argument methods.
      List<Object> argsList;
      if (args == null) {
        argsList = Lists.newArrayList();
      } else {
        argsList = Lists.newArrayList(args);
      }
      argsList.add(callback);
      args = argsList.toArray();
    }

    Object result;
    try {
      result = method.invoke(target, args);
    } catch (InvocationTargetException e) {
      // The invoked method itself threw.  Fall back to the wrapper exception in the unusual case
      // that no cause is attached, so that we never end up throwing null.
      Throwable cause = (e.getCause() != null) ? e.getCause() : e;

      // We allow this one to go to both sync and async captures.
      if (callback != null) {
        callback.onError(cause);
        return null;
      }
      capture.fail(cause);
      throw cause;
    } catch (IllegalAccessException e) {
      // The method could not be invoked at all; report the failure so that the resources held by
      // the capture are not leaked.
      capture.fail(e);
      throw e;
    } catch (RuntimeException e) {
      // Thrown by the reflective machinery itself (e.g. mismatched arguments or a null target);
      // exceptions raised by the invoked method arrive as InvocationTargetException instead.
      capture.fail(e);
      throw e;
    }

    if (callback == null) {
      capture.success();
    }
    return result;
  }

  /**
   * Requests a connection from the pool, translating pool failures into their thrift
   * counterparts.
   *
   * @param connectTimeoutOverride When non-null, the timeout to request the connection with;
   *     otherwise the configured timeout is used if it is positive, and no timeout argument is
   *     passed to the pool if it is not.
   * @return A non-null connection.
   * @throws TResourceExhaustedException If the pool reported resource exhaustion or returned
   *     {@code null}.
   * @throws TTimeoutException If the pool timed out while getting a connection.
   */
  private Connection<TTransport, InetSocketAddress> getConnection(
      Amount<Long, Time> connectTimeoutOverride)
      throws TResourceExhaustedException, TTimeoutException {
    try {
      Connection<TTransport, InetSocketAddress> connection;
      if (connectTimeoutOverride != null) {
        connection = connectionPool.get(connectTimeoutOverride);
      } else {
        connection = (timeout.getValue() > 0)
            ? connectionPool.get(timeout) : connectionPool.get();
      }

      if (connection == null) {
        throw new TResourceExhaustedException("no connection was available");
      }
      return connection;
    } catch (ResourceExhaustedException e) {
      throw new TResourceExhaustedException(e);
    } catch (TimeoutException e) {
      throw new TTimeoutException(e);
    }
  }
}