package com.lambdaworks.redis.masterslave;
import static com.lambdaworks.redis.masterslave.MasterSlaveUtils.findNodeByUri;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import com.lambdaworks.redis.RedisClient;
import com.lambdaworks.redis.RedisCommandInterruptedException;
import com.lambdaworks.redis.RedisFuture;
import com.lambdaworks.redis.RedisURI;
import com.lambdaworks.redis.api.StatefulRedisConnection;
import com.lambdaworks.redis.cluster.models.partitions.Partitions;
import com.lambdaworks.redis.models.role.RedisNodeDescription;
import com.lambdaworks.redis.output.StatusOutput;
import com.lambdaworks.redis.protocol.*;
import io.netty.buffer.ByteBuf;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
/**
 * Utility to refresh the Master-Slave topology view based on {@link RedisNodeDescription}.
 * <p>
 * A refresh obtains the node list from the {@link TopologyProvider}, opens a connection to each node, dispatches a
 * {@code PING} on every connection and returns the nodes that answered, ordered by the measured command duration.
 *
 * @author Mark Paluch
 */
class MasterSlaveTopologyRefresh {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(MasterSlaveTopologyRefresh.class);

    private final RedisClient client;
    private final TopologyProvider topologyProvider;

    /**
     * Create a new {@link MasterSlaveTopologyRefresh}.
     *
     * @param client the client used to open connections to the individual nodes
     * @param topologyProvider source of the node descriptions
     */
    public MasterSlaveTopologyRefresh(RedisClient client, TopologyProvider topologyProvider) {
        this.client = client;
        this.topologyProvider = topologyProvider;
    }

    /**
     * Load master slave nodes. Result contains an ordered list of {@link RedisNodeDescription}s. The sort key is the latency.
     * Nodes with lower latency come first.
     * <p>
     * All connections opened for the measurement are closed before this method returns, also when the measurement
     * fails with an exception.
     *
     * @param seed the seed {@link RedisURI}; its password (if set) is applied to the URIs of all nodes and its timeout
     *        bounds the total time spent waiting for {@code PING} responses
     * @return list of {@link RedisNodeDescription}s that responded, ordered by ascending latency
     * @throws RedisCommandInterruptedException if the calling thread is interrupted while waiting for a response
     */
    public List<RedisNodeDescription> getNodes(RedisURI seed) {

        List<RedisNodeDescription> nodes = topologyProvider.getNodes();
        addPasswordIfNeeded(nodes, seed);

        Map<RedisURI, StatefulRedisConnection<String, String>> connections = getConnections(nodes);
        try {
            Map<RedisURI, TimedAsyncCommand<String, String, String>> rawViews = requestPing(connections);
            return getNodeSpecificViews(rawViews, nodes, seed);
        } finally {
            // Always release the connections, even if pinging or waiting failed or was interrupted.
            close(connections);
        }
    }

    /*
     * Copy the seed password onto every node URI. Does nothing if the seed has no (or an empty) password.
     * Note: this mutates the RedisURI objects held by the given node descriptions.
     */
    private void addPasswordIfNeeded(List<RedisNodeDescription> nodes, RedisURI seed) {

        char[] seedPassword = seed.getPassword();
        if (seedPassword == null || seedPassword.length == 0) {
            return;
        }

        String password = new String(seedPassword);
        for (RedisNodeDescription node : nodes) {
            node.getUri().setPassword(password);
        }
    }

    /**
     * Wait for the outstanding {@code PING} commands and collect the nodes that answered, ordered by ascending command
     * duration.
     * <p>
     * The seed timeout is a shared budget for all commands: the time spent waiting for one command is subtracted from
     * the time available for the following ones. Processing stops as soon as the budget is used up or a command does not
     * complete within the remaining budget; commands later in iteration order are then not considered. A command that
     * completed with an error is logged and its node is left out of the result.
     *
     * @param rawViews the dispatched commands, keyed by the {@link RedisURI} they were sent to
     * @param nodes the known nodes, used to resolve a {@link RedisURI} to its {@link RedisNodeDescription}
     * @param seed the seed {@link RedisURI} providing timeout and time unit
     * @return the responsive nodes, ordered by ascending latency
     * @throws RedisCommandInterruptedException if the calling thread is interrupted while waiting; the interrupt flag is
     *         restored before throwing
     */
    protected List<RedisNodeDescription> getNodeSpecificViews(
            Map<RedisURI, TimedAsyncCommand<String, String, String>> rawViews, List<RedisNodeDescription> nodes, RedisURI seed) {

        List<RedisNodeDescription> result = new ArrayList<>();
        Map<RedisNodeDescription, Long> latencies = new HashMap<>();

        long timeout = seed.getUnit().toNanos(seed.getTimeout());
        long waitTime = 0;

        for (Map.Entry<RedisURI, TimedAsyncCommand<String, String, String>> entry : rawViews.entrySet()) {

            long timeoutLeft = timeout - waitTime;
            if (timeoutLeft <= 0) {
                break;
            }

            long startWait = System.nanoTime();
            RedisFuture<String> future = entry.getValue();

            try {
                if (!future.await(timeoutLeft, TimeUnit.NANOSECONDS)) {
                    // The remaining budget elapsed without a response.
                    break;
                }
                waitTime += System.nanoTime() - startWait;

                // Surfaces a command failure as ExecutionException; the value itself is not needed.
                future.get();

                RedisNodeDescription redisNodeDescription = findNodeByUri(nodes, entry.getKey());
                latencies.put(redisNodeDescription, entry.getValue().duration());
                result.add(redisNodeDescription);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RedisCommandInterruptedException(e);
            } catch (ExecutionException e) {
                logger.warn("Cannot retrieve partition view from " + entry.getKey(), e);
            }
        }

        Collections.sort(result, new LatencyComparator(latencies));
        return result;
    }

    /*
     * Dispatch a timed PING on every connection without waiting for the responses. The returned map holds the
     * in-flight command per RedisURI.
     */
    @SuppressWarnings("unchecked")
    private Map<RedisURI, TimedAsyncCommand<String, String, String>> requestPing(
            Map<RedisURI, StatefulRedisConnection<String, String>> connections) {

        Map<RedisURI, TimedAsyncCommand<String, String, String>> rawViews = new TreeMap<>(RedisUriComparator.INSTANCE);

        for (Map.Entry<RedisURI, StatefulRedisConnection<String, String>> entry : connections.entrySet()) {
            TimedAsyncCommand<String, String, String> timed = createPingCommand();
            entry.getValue().dispatch(timed);
            rawViews.put(entry.getKey(), timed);
        }

        return rawViews;
    }

    /**
     * Create a new, not yet dispatched {@code PING} command wrapped in a {@link TimedAsyncCommand}.
     *
     * @return the timed {@code PING} command
     */
    protected TimedAsyncCommand<String, String, String> createPingCommand() {

        CommandArgs<String, String> args = new CommandArgs<>(MasterSlaveUtils.CODEC);
        Command<String, String, String> command = new Command<>(CommandType.PING, new StatusOutput<>(MasterSlaveUtils.CODEC),
                args);
        return new TimedAsyncCommand<>(command);
    }

    /*
     * Close all given connections. A failure to close one connection is logged and does not prevent the remaining
     * connections from being closed.
     */
    private void close(Map<RedisURI, StatefulRedisConnection<String, String>> connections) {

        for (Map.Entry<RedisURI, StatefulRedisConnection<String, String>> entry : connections.entrySet()) {
            try {
                entry.getValue().close();
            } catch (RuntimeException e) {
                logger.warn("Cannot close connection to " + entry.getKey(), e);
            }
        }
    }

    /*
     * Open a connection to every node. Nodes for which connecting fails with a RuntimeException are logged and
     * skipped, so the result only contains nodes that could be connected.
     */
    private Map<RedisURI, StatefulRedisConnection<String, String>> getConnections(Iterable<RedisNodeDescription> nodes) {

        Map<RedisURI, StatefulRedisConnection<String, String>> connections = new TreeMap<>(RedisUriComparator.INSTANCE);

        for (RedisNodeDescription node : nodes) {
            try {
                StatefulRedisConnection<String, String> connection = client.connect(node.getUri());
                connections.put(node.getUri(), connection);
            } catch (RuntimeException e) {
                logger.warn("Cannot connect to " + node.getUri(), e);
            }
        }

        return connections;
    }

    /**
     * Compare {@link RedisURI} based on their host and port representation. The comparison is a case-insensitive string
     * comparison of {@code host:port}; a {@code null} URI is treated as the empty string.
     */
    static class RedisUriComparator implements Comparator<RedisURI> {

        public static final RedisUriComparator INSTANCE = new RedisUriComparator();

        @Override
        public int compare(RedisURI o1, RedisURI o2) {
            return hostAndPort(o1).compareToIgnoreCase(hostAndPort(o2));
        }

        /* Render the comparison key for a URI; null maps to the empty string. */
        private static String hostAndPort(RedisURI uri) {
            if (uri == null) {
                return "";
            }
            return uri.getHost() + ":" + uri.getPort();
        }
    }

    /**
     * Timed command that records the time at which the command was encoded and completed.
     *
     * @param <K> Key type
     * @param <V> Value type
     * @param <T> Result type
     */
    static class TimedAsyncCommand<K, V, T> extends AsyncCommand<K, V, T> {

        // System.nanoTime() values; -1 means "not recorded (yet)".
        long encodedAtNs = -1;
        long completedAtNs = -1;

        public TimedAsyncCommand(RedisCommand<K, V, T> command) {
            super(command);
        }

        /**
         * Encode the command and record the time right after encoding finished. Both timestamps are reset first so that
         * encoding the command again starts a fresh measurement.
         */
        @Override
        public void encode(ByteBuf buf) {
            completedAtNs = -1;
            encodedAtNs = -1;

            super.encode(buf);
            encodedAtNs = System.nanoTime();
        }

        /**
         * Record the completion time, then complete the command. The timestamp is taken before delegating to the
         * superclass.
         */
        @Override
        public void complete() {
            completedAtNs = System.nanoTime();
            super.complete();
        }

        /**
         * @return the nanoseconds between encoding and completion, or {@code -1} if either timestamp was not recorded
         */
        public long duration() {
            if (completedAtNs == -1 || encodedAtNs == -1) {
                return -1;
            }
            return completedAtNs - encodedAtNs;
        }
    }

    /**
     * Order {@link RedisNodeDescription}s by their latency taken from the supplied map, lowest latency first. Nodes
     * without a latency entry sort after nodes that have one; two nodes without an entry compare as equal.
     */
    static class LatencyComparator implements Comparator<RedisNodeDescription> {

        private final Map<RedisNodeDescription, Long> latencies;

        public LatencyComparator(Map<RedisNodeDescription, Long> latencies) {
            this.latencies = latencies;
        }

        @Override
        public int compare(RedisNodeDescription o1, RedisNodeDescription o2) {

            Long latency1 = latencies.get(o1);
            Long latency2 = latencies.get(o2);

            if (latency1 == null) {
                return latency2 == null ? 0 : 1;
            }
            if (latency2 == null) {
                return -1;
            }
            return latency1.compareTo(latency2);
        }
    }
}