/* * Copyright 2011-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.lambdaworks.redis.cluster; import static com.lambdaworks.redis.cluster.ClusterScanSupport.reactiveClusterKeyScanCursorMapper; import static com.lambdaworks.redis.cluster.ClusterScanSupport.reactiveClusterStreamScanCursorMapper; import static com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode.NodeFlag.MASTER; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; import rx.Observable; import rx.Single; import com.lambdaworks.redis.*; import com.lambdaworks.redis.api.StatefulRedisConnection; import com.lambdaworks.redis.api.rx.RedisKeyReactiveCommands; import com.lambdaworks.redis.api.rx.RedisScriptingReactiveCommands; import com.lambdaworks.redis.api.rx.RedisServerReactiveCommands; import com.lambdaworks.redis.api.rx.Success; import com.lambdaworks.redis.cluster.ClusterConnectionProvider.Intent; import com.lambdaworks.redis.cluster.api.rx.RedisAdvancedClusterReactiveCommands; import com.lambdaworks.redis.cluster.api.rx.RedisClusterReactiveCommands; import com.lambdaworks.redis.cluster.models.partitions.Partitions; import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode; import com.lambdaworks.redis.codec.RedisCodec; import com.lambdaworks.redis.internal.LettuceLists; import com.lambdaworks.redis.output.KeyStreamingChannel; import com.lambdaworks.redis.output.ValueStreamingChannel; /** * An advanced reactive and thread-safe API to a Redis Cluster connection. * * @author Mark Paluch * @since 4.0 */ public class RedisAdvancedClusterReactiveCommandsImpl<K, V> extends AbstractRedisReactiveCommands<K, V> implements RedisAdvancedClusterReactiveCommands<K, V> { /** * Initialize a new connection. * * @param connection the stateful connection * @param codec Codec used to encode/decode keys and values. */ public RedisAdvancedClusterReactiveCommandsImpl(StatefulRedisClusterConnectionImpl<K, V> connection, RedisCodec<K, V> codec) { super(connection, codec); } @Override public Observable<Long> del(K... keys) { return del(Arrays.asList(keys)); } @Override public Observable<Long> del(Iterable<K> keys) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keys); if (partitioned.size() < 2) { return super.del(keys); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.del(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<Long> unlink(K... keys) { return unlink(Arrays.asList(keys)); } @Override public Observable<Long> unlink(Iterable<K> keys) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keys); if (partitioned.size() < 2) { return super.unlink(keys); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.unlink(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<Long> exists(K... keys) { return exists(Arrays.asList(keys)); } public Observable<Long> exists(Iterable<K> keys) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keys); if (partitioned.size() < 2) { return super.exists(keys); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.exists(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<V> mget(K... keys) { return mget(Arrays.asList(keys)); } public Observable<V> mget(Iterable<K> keys) { List<K> keyList = LettuceLists.newList(keys); Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList); if (partitioned.size() < 2) { return super.mget(keyList); } List<Observable<V>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.mget(entry.getValue())); } Observable<V> observable = Observable.concat(Observable.from(observables)); Observable<List<V>> map = observable.toList().map(vs -> { Object[] values = new Object[vs.size()]; int offset = 0; for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { for (int i = 0; i < keyList.size(); i++) { int index = entry.getValue().indexOf(keyList.get(i)); if (index == -1) { continue; } values[i] = vs.get(offset + index); } offset += entry.getValue().size(); } List<V> objects = (List<V>) new ArrayList<>(Arrays.asList(values)); return objects; }); return map.compose(new FlattenTransform<>()); } @Override public Observable<Long> mget(ValueStreamingChannel<V> channel, K... keys) { return mget(channel, Arrays.asList(keys)); } public Observable<Long> mget(ValueStreamingChannel<V> channel, Iterable<K> keys) { List<K> keyList = LettuceLists.newList(keys); Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList); if (partitioned.size() < 2) { return super.mget(channel, keyList); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.mget(channel, entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } @Override public Observable<Boolean> msetnx(Map<K, V> map) { return pipeliningWithMap(map, super::msetnx, booleanObservable -> booleanObservable.reduce((accu, next) -> accu && next)); } @Override public Observable<String> mset(Map<K, V> map) { return pipeliningWithMap(map, super::mset, Observable::last); } @Override public Observable<K> clusterGetKeysInSlot(int slot, int count) { Single<RedisClusterReactiveCommands<K, V>> connectionBySlot = findConnectionBySlotReactive(slot); return connectionBySlot.flatMapObservable(conn -> conn.clusterGetKeysInSlot(slot, count)); } @Override public Observable<Long> clusterCountKeysInSlot(int slot) { Single<RedisClusterReactiveCommands<K, V>> connectionBySlot = findConnectionBySlotReactive(slot); return connectionBySlot.flatMapObservable(cmd -> cmd.clusterCountKeysInSlot(slot)); } @Override public Observable<String> clientSetname(K name) { List<Observable<String>> observables = new ArrayList<>(); for (RedisClusterNode redisClusterNode : getStatefulConnection().getPartitions()) { Single<RedisClusterReactiveCommands<K, V>> byNodeId = getConnectionReactive(redisClusterNode.getNodeId()); observables.add(byNodeId.flatMapObservable(conn -> { if (conn.isOpen()) { return conn.clientSetname(name); } return Observable.empty(); })); Single<RedisClusterReactiveCommands<K, V>> byHost = getConnectionReactive(redisClusterNode.getUri().getHost(), redisClusterNode.getUri().getPort()); observables.add(byHost.flatMapObservable(conn -> { if (conn.isOpen()) { return conn.clientSetname(name); } return Observable.empty(); })); } return Observable.merge(observables).last(); } @Override public Observable<Long> dbsize() { Map<String, Observable<Long>> observables = executeOnMasters(RedisServerReactiveCommands::dbsize); return Observable.merge(observables.values()).reduce((accu, next) -> accu + next); } @Override public Observable<String> flushall() { Map<String, Observable<String>> observables = executeOnMasters(RedisServerReactiveCommands::flushall); return Observable.merge(observables.values()).last(); } @Override public Observable<String> flushdb() { Map<String, Observable<String>> observables = executeOnMasters(RedisServerReactiveCommands::flushdb); return Observable.merge(observables.values()).last(); } @Override public Observable<K> keys(K pattern) { Map<String, Observable<K>> observables = executeOnMasters(commands -> commands.keys(pattern)); return Observable.merge(observables.values()); } @Override public Observable<Long> keys(KeyStreamingChannel<K> channel, K pattern) { Map<String, Observable<Long>> observables = executeOnMasters(commands -> commands.keys(channel, pattern)); return Observable.merge(observables.values()).reduce((accu, next) -> accu + next); } @Override public Observable<V> randomkey() { Partitions partitions = getStatefulConnection().getPartitions(); int index = ThreadLocalRandom.current().nextInt(partitions.size()); Single<RedisClusterReactiveCommands<K, V>> connection = getConnectionReactive(partitions.getPartition(index) .getNodeId()); return connection.flatMapObservable(RedisKeyReactiveCommands::randomkey); } @Override public Observable<String> scriptFlush() { Map<String, Observable<String>> observables = executeOnNodes(RedisScriptingReactiveCommands::scriptFlush, redisClusterNode -> true); return Observable.merge(observables.values()).last(); } @Override public Observable<String> scriptKill() { Map<String, Observable<String>> observables = executeOnNodes(RedisScriptingReactiveCommands::scriptFlush, redisClusterNode -> true); return Observable.merge(observables.values()).onErrorReturn(throwable -> "OK").last(); } @Override public Observable<Success> shutdown(boolean save) { Map<String, Observable<Success>> observables = executeOnNodes(commands -> commands.shutdown(save), redisClusterNode -> true); return Observable.merge(observables.values()).onErrorReturn(throwable -> null).last(); } @Override public Observable<Long> touch(K... keys) { return touch(Arrays.asList(keys)); } public Observable<Long> touch(Iterable<K> keys) { List<K> keyList = LettuceLists.newList(keys); Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList); if (partitioned.size() < 2) { return super.touch(keyList); } List<Observable<Long>> observables = new ArrayList<>(); for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) { observables.add(super.touch(entry.getValue())); } return Observable.merge(observables).reduce((accu, next) -> accu + next); } /** * Run a command on all available masters, * * @param function function producing the command * @param <T> result type * @return map of a key (counter) and commands. */ protected <T> Map<String, Observable<T>> executeOnMasters( Function<RedisClusterReactiveCommands<K, V>, Observable<T>> function) { return executeOnNodes(function, redisClusterNode -> redisClusterNode.is(MASTER)); } /** * Run a command on all available nodes that match {@code filter}. * * @param function function producing the command * @param filter filter function for the node selection * @param <T> result type * @return map of a key (counter) and commands. */ protected <T> Map<String, Observable<T>> executeOnNodes( Function<RedisClusterReactiveCommands<K, V>, Observable<T>> function, Function<RedisClusterNode, Boolean> filter) { Map<String, Observable<T>> executions = new HashMap<>(); for (RedisClusterNode redisClusterNode : getStatefulConnection().getPartitions()) { if (!filter.apply(redisClusterNode)) { continue; } RedisURI uri = redisClusterNode.getUri(); Single<RedisClusterReactiveCommands<K, V>> connection = getConnectionReactive(uri.getHost(), uri.getPort()); executions.put(redisClusterNode.getNodeId(), connection.flatMapObservable(function::apply)); } return executions; } private Single<RedisClusterReactiveCommands<K, V>> findConnectionBySlotReactive(int slot) { RedisClusterNode node = getStatefulConnection().getPartitions().getPartitionBySlot(slot); if (node != null) { return getConnectionReactive(node.getUri().getHost(), node.getUri().getPort()); } return Single.error(new RedisException("No partition for slot " + slot)); } @Override public StatefulRedisClusterConnectionImpl<K, V> getStatefulConnection() { return (StatefulRedisClusterConnectionImpl<K, V>) connection; } @Override public RedisClusterReactiveCommands<K, V> getConnection(String nodeId) { return getStatefulConnection().getConnection(nodeId).reactive(); } private Single<RedisClusterReactiveCommands<K, V>> getConnectionReactive(String nodeId) { return getSingle(getConnectionProvider().<K, V> getConnectionAsync(Intent.WRITE, nodeId)).map( StatefulRedisConnection::reactive); } @Override public RedisClusterReactiveCommands<K, V> getConnection(String host, int port) { return getStatefulConnection().getConnection(host, port).reactive(); } private Single<RedisClusterReactiveCommands<K, V>> getConnectionReactive(String host, int port) { return getSingle(getConnectionProvider().<K, V> getConnectionAsync(Intent.WRITE, host, port)).map( StatefulRedisConnection::reactive); } private AsyncClusterConnectionProvider getConnectionProvider() { return (AsyncClusterConnectionProvider) getStatefulConnection().getClusterDistributionChannelWriter() .getClusterConnectionProvider(); } @Override public Observable<KeyScanCursor<K>> scan() { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<KeyScanCursor<K>> scan(ScanArgs scanArgs) { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(scanArgs), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<KeyScanCursor<K>> scan(ScanCursor scanCursor, ScanArgs scanArgs) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(cursor, scanArgs), reactiveClusterKeyScanCursorMapper()); } @Override public Observable<KeyScanCursor<K>> scan(ScanCursor scanCursor) { return clusterScan(scanCursor, RedisKeyReactiveCommands::scan, reactiveClusterKeyScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel) { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(channel), reactiveClusterStreamScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel, ScanArgs scanArgs) { return clusterScan(ScanCursor.INITIAL, (connection, cursor) -> connection.scan(channel, scanArgs), reactiveClusterStreamScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel, ScanCursor scanCursor, ScanArgs scanArgs) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(channel, cursor, scanArgs), reactiveClusterStreamScanCursorMapper()); } @Override public Observable<StreamScanCursor> scan(KeyStreamingChannel<K> channel, ScanCursor scanCursor) { return clusterScan(scanCursor, (connection, cursor) -> connection.scan(channel, cursor), reactiveClusterStreamScanCursorMapper()); } @SuppressWarnings("unchecked") private <T extends ScanCursor> Observable<T> clusterScan(ScanCursor cursor, BiFunction<RedisKeyReactiveCommands<K, V>, ScanCursor, Observable<T>> scanFunction, ClusterScanSupport.ScanCursorMapper<Observable<T>> resultMapper) { return clusterScan(getStatefulConnection(), cursor, scanFunction, (ClusterScanSupport.ScanCursorMapper) resultMapper); } /** * Perform a SCAN in the cluster. * */ static <T extends ScanCursor, K, V> Observable<T> clusterScan(StatefulRedisClusterConnectionImpl<K, V> connection, ScanCursor cursor, BiFunction<RedisKeyReactiveCommands<K, V>, ScanCursor, Observable<T>> scanFunction, ClusterScanSupport.ScanCursorMapper<Observable<T>> mapper) { List<String> nodeIds = ClusterScanSupport.getNodeIds(connection, cursor); String currentNodeId = ClusterScanSupport.getCurrentNodeId(cursor, nodeIds); ScanCursor continuationCursor = ClusterScanSupport.getContinuationCursor(cursor); AsyncClusterConnectionProvider connectionProvider = (AsyncClusterConnectionProvider) connection .getClusterDistributionChannelWriter().getClusterConnectionProvider(); Observable<T> scanCursor = getSingle(connectionProvider.<K, V> getConnectionAsync(Intent.WRITE, currentNodeId)) .flatMapObservable(conn -> scanFunction.apply(conn.reactive(), continuationCursor)); return mapper.map(nodeIds, currentNodeId, scanCursor); } private <T> Observable<T> pipeliningWithMap(Map<K, V> map, Function<Map<K, V>, Observable<T>> function, Function<Observable<T>, Observable<T>> resultFunction) { Map<Integer, List<K>> partitioned = SlotHash.partition(codec, map.keySet()); if (partitioned.size() < 2) { return function.apply(map); } List<Observable<T>> observables = partitioned.values().stream().map(ks -> { Map<K, V> op = new HashMap<>(); ks.forEach(k -> op.put(k, map.get(k))); return function.apply(op); }).collect(Collectors.toList()); return resultFunction.apply(Observable.merge(observables)); } static class FlattenTransform<T> implements Observable.Transformer<Iterable<T>, T> { @Override public Observable<T> call(Observable<Iterable<T>> source) { return source.flatMap(Observable::from); } } private static <T> Single<T> getSingle(CompletableFuture<T> future) { return Single.create(singleSubscriber -> { future.whenComplete((connection, throwable) -> { if (throwable != null) { singleSubscriber.onError(throwable); } else { singleSubscriber.onSuccess(connection); } }); }); } }