package com.lambdaworks.redis.cluster;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;
import com.lambdaworks.redis.RedisCommandExecutionException;
import com.lambdaworks.redis.RedisCommandInterruptedException;
import com.lambdaworks.redis.RedisCommandTimeoutException;
import com.lambdaworks.redis.api.StatefulRedisConnection;
import com.lambdaworks.redis.cluster.api.NodeSelectionSupport;
import com.lambdaworks.redis.cluster.api.async.RedisClusterAsyncCommands;
import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode;
import com.lambdaworks.redis.internal.AbstractInvocationHandler;
import com.lambdaworks.redis.internal.LettuceAssert;
/**
 * Invocation handler to trigger commands on multiple connections and return a holder for the values.
 * <p>
 * Methods that match a method of {@link RedisClusterAsyncCommands} are invoked on the async API of every connection in
 * the node selection. In async mode the per-node {@link CompletionStage}s are returned wrapped in an
 * {@code AsyncExecutionsImpl}. In sync mode the handler first waits up to the configured timeout for all of them and
 * then returns a {@code SyncExecutionsImpl}. A no-arg {@code commands()} call returns the proxy itself. All other calls
 * are delegated to the matching {@link NodeSelectionSupport} method of the selection.
 *
 * @author Mark Paluch
 */
class NodeSelectionInvocationHandler extends AbstractInvocationHandler {

    /**
     * Sentinel stored in the method caches to remember that a lookup found no match. It is declared on this class, so it
     * can never be returned by scanning the interfaces passed to {@link #findMethod}.
     */
    private static final Method NULL_MARKER_METHOD;

    /** Cache: invoked method to matching {@link NodeSelectionSupport} method, or {@link #NULL_MARKER_METHOD}. */
    private final Map<Method, Method> nodeSelectionMethods = new ConcurrentHashMap<>();

    /** Cache: invoked method to matching {@link RedisClusterAsyncCommands} method, or {@link #NULL_MARKER_METHOD}. */
    private final Map<Method, Method> connectionMethod = new ConcurrentHashMap<>();

    private final AbstractNodeSelection<?, ?, ?, ?> selection;
    private final boolean sync;
    private final long timeout;
    private final TimeUnit unit;

    static {
        try {
            NULL_MARKER_METHOD = NodeSelectionInvocationHandler.class.getDeclaredMethod("handleInvocation", Object.class,
                    Method.class, Object[].class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Creates a handler that returns asynchronous execution holders.
     *
     * @param selection the node selection whose connections receive the commands.
     */
    public NodeSelectionInvocationHandler(AbstractNodeSelection<?, ?, ?, ?> selection) {
        this(selection, false, 0, null);
    }

    /**
     * Creates a handler.
     *
     * @param selection the node selection whose connections receive the commands.
     * @param sync {@literal true} to wait for all nodes before returning a synchronous execution holder.
     * @param timeout maximum time to wait in sync mode; must be greater than 0 when {@code sync} is set.
     * @param unit unit of {@code timeout}; must not be {@literal null} when {@code sync} is set.
     */
    public NodeSelectionInvocationHandler(AbstractNodeSelection<?, ?, ?, ?> selection, boolean sync, long timeout,
            TimeUnit unit) {
        if (sync) {
            LettuceAssert.isTrue(timeout > 0, "Timeout must be greater 0 when using sync mode");
            LettuceAssert.notNull(unit, "Unit must not be null when using sync mode");
        }
        this.selection = selection;
        this.sync = sync;
        this.unit = unit;
        this.timeout = timeout;
    }

    /**
     * Dispatches a proxy call either to all connections of the selection (command methods) or to the selection itself.
     *
     * @throws RedisCommandTimeoutException in sync mode, if not all nodes completed within the timeout.
     * @throws RedisCommandExecutionException in sync mode, if at least one node completed exceptionally.
     * @throws UnsupportedOperationException if the method matches neither a command nor a node-selection method.
     */
    @Override
    @SuppressWarnings("rawtypes")
    protected Object handleInvocation(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            Method targetMethod = findMethod(RedisClusterAsyncCommands.class, method, connectionMethod);

            if (targetMethod != null) {
                // Snapshot the selection's connections so the fan-out works on a stable node set.
                Map<RedisClusterNode, StatefulRedisConnection<?, ?>> connections = new HashMap<>(selection.size(), 1);
                connections.putAll(selection.statefulMap());

                Map<RedisClusterNode, CompletionStage<?>> executions = new HashMap<>();
                for (Map.Entry<RedisClusterNode, StatefulRedisConnection<?, ?>> entry : connections.entrySet()) {
                    CompletionStage<?> result = (CompletionStage<?>) targetMethod.invoke(entry.getValue().async(), args);
                    executions.put(entry.getKey(), result);
                }

                if (sync) {
                    if (!awaitAll(timeout, unit, executions.values())) {
                        throw createTimeoutException(executions);
                    }
                    if (atLeastOneFailed(executions)) {
                        throw createExecutionException(executions);
                    }
                    return new SyncExecutionsImpl(executions);
                }
                return new AsyncExecutionsImpl<>((Map) executions);
            }

            // Dynamic proxies may pass null instead of an empty array for no-arg methods, so accept both.
            if (method.getName().equals("commands") && (args == null || args.length == 0)) {
                return proxy;
            }

            targetMethod = findMethod(NodeSelectionSupport.class, method, nodeSelectionMethods);
            if (targetMethod == null) {
                throw new UnsupportedOperationException("Method not supported by node selection: " + method);
            }
            return targetMethod.invoke(selection, args);
        } catch (InvocationTargetException e) {
            // Unwrap so callers see the exception thrown by the target method itself.
            throw e.getTargetException();
        }
    }

    /**
     * Waits until all futures are complete or the overall timeout elapses. A future that completed exceptionally
     * still counts as complete.
     *
     * @param timeout overall time budget shared by all futures.
     * @param unit unit of {@code timeout}.
     * @param futures the stages to wait for.
     * @return {@literal true} if every future completed within the timeout, {@literal false} otherwise.
     * @throws RedisCommandInterruptedException if the waiting thread is interrupted, or waiting fails for another reason.
     */
    public static boolean awaitAll(long timeout, TimeUnit unit, Collection<CompletionStage<?>> futures) {
        boolean complete;
        try {
            long nanos = unit.toNanos(timeout);
            long time = System.nanoTime();
            for (CompletionStage<?> f : futures) {
                if (nanos < 0) {
                    return false;
                }
                try {
                    f.toCompletableFuture().get(nanos, TimeUnit.NANOSECONDS);
                } catch (ExecutionException ignored) {
                    // A failed future is still complete; failures are inspected by the caller.
                }
                // Deduct the time spent waiting so the timeout applies to the whole batch, not each future.
                long now = System.nanoTime();
                nanos -= now - time;
                time = now;
            }
            complete = true;
        } catch (TimeoutException e) {
            complete = false;
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RedisCommandInterruptedException(e);
        } catch (Exception e) {
            throw new RedisCommandInterruptedException(e);
        }
        return complete;
    }

    /** Returns {@literal true} if any execution completed exceptionally. */
    private boolean atLeastOneFailed(Map<RedisClusterNode, CompletionStage<?>> executions) {
        return executions.values().stream()
                .anyMatch(completionStage -> completionStage.toCompletableFuture().isCompletedExceptionally());
    }

    /** Builds a timeout exception that names every node whose execution has not completed yet. */
    private RedisCommandTimeoutException createTimeoutException(Map<RedisClusterNode, CompletionStage<?>> executions) {
        List<RedisClusterNode> notFinished = new ArrayList<>();
        executions.forEach((redisClusterNode, completionStage) -> {
            if (!completionStage.toCompletableFuture().isDone()) {
                notFinished.add(redisClusterNode);
            }
        });
        String description = getNodeDescription(notFinished);
        return new RedisCommandTimeoutException("Command timed out for node(s): " + description);
    }

    /**
     * Builds an execution exception that names every node whose execution completed exceptionally. Each node's failure
     * cause is attached as a suppressed exception.
     */
    private RedisCommandExecutionException createExecutionException(Map<RedisClusterNode, CompletionStage<?>> executions) {
        List<RedisClusterNode> failed = new ArrayList<>();
        executions.forEach((redisClusterNode, completionStage) -> {
            if (completionStage.toCompletableFuture().isCompletedExceptionally()) {
                failed.add(redisClusterNode);
            }
        });

        RedisCommandExecutionException e = new RedisCommandExecutionException(
                "Multi-node command execution failed on node(s): " + getNodeDescription(failed));

        executions.forEach((redisClusterNode, completionStage) -> {
            CompletableFuture<?> completableFuture = completionStage.toCompletableFuture();
            if (completableFuture.isCompletedExceptionally()) {
                try {
                    // The future is already complete, so get() does not block; it only rethrows the failure.
                    completableFuture.get();
                } catch (Exception innerException) {
                    if (innerException instanceof ExecutionException) {
                        e.addSuppressed(innerException.getCause());
                    } else {
                        e.addSuppressed(innerException);
                    }
                }
            }
        });
        return e;
    }

    /** Renders the nodes as a comma-separated list of {@code nodeId (host:port)} descriptors. */
    private String getNodeDescription(List<RedisClusterNode> nodes) {
        return nodes.stream().map(this::getDescriptor).collect(Collectors.joining(", "));
    }

    /** Renders a single node as {@code nodeId (host:port)}, or {@code nodeId ()} when the node has no URI. */
    private String getDescriptor(RedisClusterNode redisClusterNode) {
        StringBuilder builder = new StringBuilder().append(redisClusterNode.getNodeId()).append(" (");
        if (redisClusterNode.getUri() != null) {
            builder.append(redisClusterNode.getUri().getHost()).append(':').append(redisClusterNode.getUri().getPort());
        }
        return builder.append(')').toString();
    }

    /**
     * Looks up the method of {@code type} with the same name and parameter types as {@code method}. Hits and misses are
     * both cached.
     *
     * @return the matching method, or {@literal null} if {@code type} declares none.
     */
    private Method findMethod(Class<?> type, Method method, Map<Method, Method> cache) {
        Method cached = cache.get(method);
        if (cached != null) {
            // A cached null marker means a previous scan found nothing; do not rescan.
            return cached == NULL_MARKER_METHOD ? null : cached;
        }

        for (Method typeMethod : type.getMethods()) {
            if (typeMethod.getName().equals(method.getName())
                    && Arrays.equals(typeMethod.getParameterTypes(), method.getParameterTypes())) {
                cache.put(method, typeMethod);
                return typeMethod;
            }
        }

        // Null-marker to avoid full class method scans.
        cache.put(method, NULL_MARKER_METHOD);
        return null;
    }
}