package org.infinispan.query.affinity;
import static java.util.stream.Collectors.toList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import org.infinispan.commons.util.CollectionFactory;
import org.infinispan.distribution.ch.ConsistentHash;
import org.infinispan.query.logging.Log;
import org.infinispan.remoting.transport.Address;
import org.infinispan.util.logging.LogFactory;
/**
 * {@link ShardDistribution} that maintains a fixed number of index shards.
 * The minimum number of shards is 1 and the maximum is the number of segments.
 * <p>
 * Shards are allocated round-robin to the members that are primary owners of at least one segment; the primary
 * segments of each such member are then split among the shards allocated to it. Primary segments of members that
 * received no shard are spread round-robin over shard identifiers {@code 0..numShards-1}.
 * <p>
 * The whole distribution is computed in the constructor; the internal maps are not modified afterwards.
 *
 * @since 9.0
 */
class FixedShardsDistribution implements ShardDistribution {
   private static final Log LOGGER = LogFactory.getLog(FixedShardsDistribution.class, Log.class);
   // segment id -> identifier of the shard the segment belongs to
   private final Map<Integer, String> shardPerSegmentMap = CollectionFactory.makeConcurrentMap();
   // member -> identifiers of the shards allocated to it (only members with at least one shard have an entry)
   private final Map<Address, Set<String>> shardsPerAddressMap = CollectionFactory.makeConcurrentMap();
   // shard identifier -> member owning the shard
   private final Map<String, Address> addressPerShardMap = CollectionFactory.makeConcurrentMap();
   private final int numShards;

   /**
    * Computes the shard distribution for the given consistent hash.
    *
    * @param consistentHash consistent hash providing the members and their primary segments
    * @param numShards      number of shards; must be between 1 and the number of segments (inclusive)
    * @throws IllegalArgumentException if {@code numShards} is lower than 1 or higher than the number of segments
    * @throws IllegalStateException    if no member is primary owner of any segment
    */
   FixedShardsDistribution(ConsistentHash consistentHash, int numShards) {
      if (numShards > consistentHash.getNumSegments()) {
         throw new IllegalArgumentException("Number of shards cannot be higher than number of segments");
      }
      if (numShards < 0) {
         throw new IllegalArgumentException("Number of shards cannot be negative");
      }
      if (numShards == 0) {
         // With zero shards no segment could be assigned to an existing shard, and the round-robin over
         // shard ids ((i + 1) % numShards) would divide by zero.
         throw new IllegalArgumentException("Number of shards must be at least 1");
      }
      this.numShards = numShards;
      this.calculate(consistentHash, numShards);
   }

   /**
    * Fills the three lookup maps from the members and primary segments of {@code consistentHash}.
    */
   private void calculate(ConsistentHash consistentHash, int numShards) {
      List<Address> nodes = consistentHash.getMembers();
      int numNodes = nodes.size();
      List<Set<Integer>> segmentsPerServer = nodes.stream()
            .map(consistentHash::getPrimarySegmentsForOwner).collect(toList());
      int[] shardsNumPerServer = allocateShardsToNodes(numShards, numNodes, segmentsPerServer);
      this.populateSegments(shardsNumPerServer, segmentsPerServer, nodes);
      LOGGER.tracef("Calculated shard distribution shardPerSegmentMap: %s", shardPerSegmentMap);
      LOGGER.tracef("Calculated shard distribution shardsPerAddressMap: %s", shardsPerAddressMap);
      LOGGER.tracef("Calculated shard distribution addressPerShardMap: %s", addressPerShardMap);
   }

   /**
    * Associates segments to each shard.
    * <p>
    * Shard identifiers are assigned sequentially ({@code "0"}, {@code "1"}, ...) in member order. Segments of
    * members that were allocated no shard are assigned round-robin to shard ids {@code 0..numShards-1}
    * (this assumes {@code split} returns exactly the requested number of parts — TODO confirm).
    *
    * @param shardsNumPerServer numbers of shards allocated for each server
    * @param segmentsPerServer  the primary owned segments of each server
    * @param nodes              the members of the cluster
    */
   private void populateSegments(int[] shardsNumPerServer, List<Set<Integer>> segmentsPerServer, List<Address> nodes) {
      int shardId = 0;
      Set<Integer> remainingSegments = new HashSet<>();
      for (int n = 0; n < nodes.size(); n++) {
         Address node = nodes.get(n);
         Collection<Integer> primarySegments = segmentsPerServer.get(n);
         int shardQuantity = shardsNumPerServer[n];
         if (shardQuantity == 0) {
            // This member hosts no shard: its segments are spread over the existing shards below.
            remainingSegments.addAll(primarySegments);
            continue;
         }
         Set<String> nodeShards = shardsPerAddressMap.computeIfAbsent(node, a -> new HashSet<>(shardQuantity));
         for (Collection<Integer> shardSegments : this.split(primarySegments, shardQuantity)) {
            String id = String.valueOf(shardId++);
            shardSegments.forEach(seg -> shardPerSegmentMap.put(seg, id));
            nodeShards.add(id);
            addressPerShardMap.put(id, node);
         }
      }
      if (!remainingSegments.isEmpty()) {
         Iterator<String> shardIterator = Stream.iterate(0, i -> (i + 1) % numShards).map(String::valueOf).iterator();
         for (Integer segment : remainingSegments) {
            shardPerSegmentMap.put(segment, shardIterator.next());
         }
      }
   }

   /**
    * Allocates shards in a round robin fashion for the servers, ignoring those without segments.
    *
    * @param numShards       number of shards to allocate
    * @param numNodes        number of servers
    * @param weightPerServer primary segments of each server; only emptiness is checked
    * @return int[] with the number of shards per server
    * @throws IllegalStateException if shards must be allocated but no server owns any segment
    *                               (the round robin would otherwise never terminate)
    */
   private static int[] allocateShardsToNodes(int numShards, int numNodes, List<Set<Integer>> weightPerServer) {
      int[] shardsPerServer = new int[numNodes];
      if (numShards <= 0) {
         return shardsPerServer;
      }
      if (weightPerServer.stream().allMatch(Set::isEmpty)) {
         throw new IllegalStateException("Cannot allocate " + numShards
               + " shards: no member is primary owner of any segment");
      }
      Iterator<Integer> cyclicNodeIterator = Stream.iterate(0, i -> (i + 1) % numNodes).iterator();
      int remaining = numShards;
      while (remaining > 0) {
         int slot = cyclicNodeIterator.next();
         if (!weightPerServer.get(slot).isEmpty()) {
            shardsPerServer[slot]++;
            remaining--;
         }
      }
      return shardsPerServer;
   }

   /**
    * @return read-only view of all shard identifiers
    */
   @Override
   public Set<String> getShardsIdentifiers() {
      return Collections.unmodifiableSet(addressPerShardMap.keySet());
   }

   /**
    * @return read-only view of the shards allocated to {@code address}, or {@code null} if it owns no shard
    */
   @Override
   public Set<String> getShards(Address address) {
      Set<String> shards = shardsPerAddressMap.get(address);
      // Keep the null result for members without shards, but do not expose the internal mutable set.
      return shards == null ? null : Collections.unmodifiableSet(shards);
   }

   /**
    * @return identifier of the shard containing {@code segment}, or {@code null} if the segment is unknown
    */
   @Override
   public String getShardFromSegment(Integer segment) {
      return shardPerSegmentMap.get(segment);
   }

   /**
    * @return member owning {@code shardId}, or {@code null} if the shard is unknown
    */
   @Override
   public Address getOwner(String shardId) {
      return addressPerShardMap.get(shardId);
   }
}