package com.lambdaworks.redis.cluster; import static com.lambdaworks.redis.cluster.ClusterScanSupport.reactiveClusterKeyScanCursorMapper; import static com.lambdaworks.redis.cluster.ClusterScanSupport.reactiveClusterStreamScanCursorMapper; import static com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode.NodeFlag.MASTER; import java.util.*; import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; import com.lambdaworks.redis.internal.LettuceLists; import rx.Observable; import com.lambdaworks.redis.*; import com.lambdaworks.redis.api.rx.RedisKeyReactiveCommands; import com.lambdaworks.redis.api.rx.RedisScriptingReactiveCommands; import com.lambdaworks.redis.api.rx.RedisServerReactiveCommands; import com.lambdaworks.redis.api.rx.Success; import com.lambdaworks.redis.cluster.api.StatefulRedisClusterConnection; import com.lambdaworks.redis.cluster.api.rx.RedisAdvancedClusterReactiveCommands; import com.lambdaworks.redis.cluster.api.rx.RedisClusterReactiveCommands; import com.lambdaworks.redis.cluster.models.partitions.Partitions; import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode; import com.lambdaworks.redis.codec.RedisCodec; import com.lambdaworks.redis.output.KeyStreamingChannel; import com.lambdaworks.redis.output.ValueStreamingChannel; /** * An advanced reactive and thread-safe API to a Redis Cluster connection. * * @author Mark Paluch * @since 4.0 */ public class RedisAdvancedClusterReactiveCommandsImpl<K, V> extends AbstractRedisReactiveCommands<K, V> implements RedisAdvancedClusterReactiveCommands<K, V> { private final Random random = new Random(); /** * Initialize a new connection. * * @param connection the stateful connection * @param codec Codec used to encode/decode keys and values. */ public RedisAdvancedClusterReactiveCommandsImpl(StatefulRedisClusterConnectionImpl<K, V> connection, RedisCodec<K, V> codec) { super(connection, codec); } @Override public Observable<Long> del(K... keys) { return del(Arrays.asList(keys)); } @Override public Observable<Long> del(Iterable<K> keys) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keys); if (partitioned.size() < 2) { return super.del(keys); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.del(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<Long> unlink(K... keys) { return unlink(Arrays.asList(keys)); } @Override public Observable<Long> unlink(Iterable<K> keys) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keys); if (partitioned.size() < 2) { return super.unlink(keys); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.unlink(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<Long> exists(K... keys) { return exists(Arrays.asList(keys)); } public Observable<Long> exists(Iterable<K> keys) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keys); if (partitioned.size() < 2) { return super.exists(keys); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.exists(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<V> mget(K... keys) { return mget(Arrays.asList(keys)); } public Observable<V> mget(Iterable<K> keys) { List<K> keyList = LettuceLists.newList(keys); Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList); if (partitioned.size() < 2) { return super.mget(keyList); } List<Observable<V>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.mget(entry.getValue())); } Observable<V> observable = Observable.concat(Observable.from(observables)); Observable<List<V>> map = observable.toList().map(vs -> { Object[] values = new Object[vs.size()]; int offset = 0; for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { for (int i = 0; i < keyList.size(); i++) { int index = entry.getValue().indexOf(keyList.get(i)); if (index == -1) { continue; } values[i] = vs.get(offset + index); } offset += entry.getValue().size(); } List<V> objects = (List<V>) new ArrayList<>(Arrays.asList(values)); return objects; }); return map.compose(new FlattenTransform<>()); } @Override public Observable<Long> mget(ValueStreamingChannel<V> channel, K... keys) { return mget(channel, Arrays.asList(keys)); } public Observable<Long> mget(ValueStreamingChannel<V> channel, Iterable<K> keys) { List<K> keyList = LettuceLists.newList(keys); Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList); if (partitioned.size() < 2) { return super.mget(channel, keyList); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.mget(channel, entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<Boolean> msetnx(Map<K, V> map) { return pipeliningWithMap(map, kvMap -> super.msetnx(kvMap), booleanObservable -> booleanObservable.reduce((accu, next) -> accu && next)); } @Override public Observable<String> mset(Map<K, V> map) { return pipeliningWithMap(map, kvMap -> super.mset(kvMap), Observable::last); } @Override public Observable<K> clusterGetKeysInSlot(int slot, int count) { RedisClusterReactiveCommands<K, V> connectionBySlot = findConnectionBySlot(slot); if (connectionBySlot != null) { return connectionBySlot.clusterGetKeysInSlot(slot, count); } return super.clusterGetKeysInSlot(slot, count); } @Override public Observable<Long> clusterCountKeysInSlot(int slot) { RedisClusterReactiveCommands<K, V> connectionBySlot = findConnectionBySlot(slot); if (connectionBySlot != null) { return connectionBySlot.clusterCountKeysInSlot(slot); } return super.clusterCountKeysInSlot(slot); } @Override public Observable<String> clientSetname(K name) { List<Observable<String>> observables = new ArrayList<>(); for (RedisClusterNode redisClusterNode : getStatefulConnection().getPartitions()) { RedisClusterReactiveCommands<K, V> byNodeId = getConnection(redisClusterNode.getNodeId()); if (byNodeId.isOpen()) { observables.add(byNodeId.clientSetname(name)); } RedisClusterReactiveCommands<K, V> byHost = getConnection(redisClusterNode.getUri().getHost(), redisClusterNode .getUri().getPort()); if (byHost.isOpen()) { observables.add(byHost.clientSetname(name)); } } return Observable.merge(observables).last(); } @Override public Observable<Long> dbsize() { Map<String, Observable<Long>> observables = executeOnMasters(RedisServerReactiveCommands::dbsize); return Observable.merge(observables.values()).reduce((accu, next) -> accu + next); } @Override public Observable<String> flushall() { Map<String, Observable<String>> observables = executeOnMasters(RedisServerReactiveCommands::flushall); return Observable.merge(observables.values()).last(); } @Override public Observable<String> flushdb() { Map<String, Observable<String>> observables = executeOnMasters(RedisServerReactiveCommands::flushdb); return Observable.merge(observables.values()).last(); } @Override public Observable<K> keys(K pattern) { Map<String, Observable<K>> observables = executeOnMasters(commands -> commands.keys(pattern)); return Observable.merge(observables.values()); } @Override public Observable<Long> keys(KeyStreamingChannel<K> channel, K pattern) { Map<String, Observable<Long>> observables = executeOnMasters(commands -> commands.keys(channel, pattern)); return Observable.merge(observables.values()).reduce((accu, next) -> accu + next); } @Override public Observable<V> randomkey() { Partitions partitions = getStatefulConnection().getPartitions(); int index = random.nextInt(partitions.size()); RedisClusterReactiveCommands<K, V> connection = getConnection(partitions.getPartition(index).getNodeId()); return connection.randomkey(); } @Override public Observable<String> scriptFlush() { Map<String, Observable<String>> observables = executeOnNodes(RedisScriptingReactiveCommands::scriptFlush, redisClusterNode -> true); return Observable.merge(observables.values()).last(); } @Override public Observable<String> scriptKill() { Map<String, Observable<String>> observables = executeOnNodes(RedisScriptingReactiveCommands::scriptFlush, redisClusterNode -> true); return Observable.merge(observables.values()).onErrorReturn(throwable -> "OK").last(); } @Override public Observable<Success> shutdown(boolean save) { Map<String, Observable<Success>> observables = executeOnNodes(commands -> commands.shutdown(save), redisClusterNode -> true); return Observable.merge(observables.values()).onErrorReturn(throwable -> null).last(); } @Override public Observable<Long> touch(K... keys) { return touch(Arrays.asList(keys)); } public Observable<Long> touch(Iterable<K> keys) { List<K> keyList = LettuceLists.newList(keys); Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList); if (partitioned.size() < 2) { return super.touch(keyList); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.touch(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } /** * Run a command on all available masters, * * @param function function producing the command * @param <T> result type * @return map of a key (counter) and commands. */ protected <T> Map<String, Observable<T>> executeOnMasters( Function<RedisClusterReactiveCommands<K, V>, Observable<T>> function) { return executeOnNodes(function, redisClusterNode -> redisClusterNode.is(MASTER)); } /** * Run a command on all available nodes that match {@code filter}. * * @param function function producing the command * @param filter filter function for the node selection * @param <T> result type * @return map of a key (counter) and commands. */ protected <T> Map<String, Observable<T>> executeOnNodes( Function<RedisClusterReactiveCommands<K, V>, Observable<T>> function, Function<RedisClusterNode, Boolean> filter) { Map<String, Observable<T>> executions = new HashMap<>(); for (RedisClusterNode redisClusterNode : getStatefulConnection().getPartitions()) { if (!filter.apply(redisClusterNode)) { continue; } RedisURI uri = redisClusterNode.getUri(); RedisClusterReactiveCommands<K, V> connection = getConnection(uri.getHost(), uri.getPort()); if (connection.isOpen()) { executions.put(redisClusterNode.getNodeId(), function.apply(connection)); } } return executions; } private RedisClusterReactiveCommands<K, V> findConnectionBySlot(int slot) { RedisClusterNode node = getStatefulConnection().getPartitions().getPartitionBySlot(slot); if (node != null) { return getConnection(node.getUri().getHost(), node.getUri().getPort()); } return null; } @Override public StatefulRedisClusterConnection<K, V> getStatefulConnection() { return (StatefulRedisClusterConnection<K, V>) connection; } @Override public RedisClusterReactiveCommands<K, V> getConnection(String nodeId) { return getStatefulConnection().getConnection(nodeId).reactive(); } @Override public RedisClusterReactiveCommands<K, V> getConnection(String host, int port) { return getStatefulConnection().getConnection(host, port).reactive(); } @Override public Observable<KeyScanCursor<K>> scan() { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<KeyScanCursor<K>> scan(ScanArgs scanArgs) { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(scanArgs), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<KeyScanCursor<K>> scan(ScanCursor scanCursor, ScanArgs scanArgs) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(cursor, scanArgs), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<KeyScanCursor<K>> scan(ScanCursor scanCursor) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(cursor), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel) { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(channel), reactiveClusterStreamScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel, ScanArgs scanArgs) { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(channel, scanArgs), reactiveClusterStreamScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel, ScanCursor scanCursor, ScanArgs scanArgs) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(channel, cursor, scanArgs), reactiveClusterStreamScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel, ScanCursor scanCursor) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(channel, cursor), reactiveClusterStreamScanCursorMapper()); } private <T extends ScanCursor> Observable<T> clusterScan(ScanCursor cursor, BiFunction<RedisKeyReactiveCommands<K, V>, ScanCursor, Observable<T>> scanFunction, ClusterScanSupport.ScanCursorMapper<Observable<T>> resultMapper) { return clusterScan(getStatefulConnection(), cursor, scanFunction, (ClusterScanSupport.ScanCursorMapper) resultMapper); } /** * Perform a SCAN in the cluster. * */ static <T extends ScanCursor, K, V> Observable<T> clusterScan(StatefulRedisClusterConnection<K, V> connection, ScanCursor cursor, BiFunction<RedisKeyReactiveCommands<K, V>, ScanCursor, Observable<T>> scanFunction, ClusterScanSupport.ScanCursorMapper<Observable<T>> mapper) { List<String> nodeIds = ClusterScanSupport.getNodeIds(connection, cursor); String currentNodeId = ClusterScanSupport.getCurrentNodeId(cursor, nodeIds); ScanCursor continuationCursor = ClusterScanSupport.getContinuationCursor(cursor); Observable<T> scanCursor = scanFunction.apply(connection.getConnection(currentNodeId).reactive(), continuationCursor); return mapper.map(nodeIds, currentNodeId, scanCursor); } private <T> Observable<T> pipeliningWithMap(Map<K, V> map, Function<Map<K, V>, Observable<T>> function, Function<Observable<T>, Observable<T>> resultFunction) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, map.keySet()); if (partitioned.size() < 2) { return function.apply(map); } List<Observable<T>> observables = partitioned.values().stream().map(ks -> { Map<K, V> op = new HashMap<>(); ks.forEach(k -> op.put(k, map.get(k))); return function.apply(op); }).collect(Collectors.toList()); return resultFunction.apply(Observable.merge(observables)); } static class FlattenTransform<T> implements Observable.Transformer<Iterable<T>, T> { @Override public Observable<T> call(Observable<Iterable<T>> source) { return source.flatMap(values -> Observable.from(values)); } } }