/*
* Copyright (C) 2015 SoftIndex LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.datakernel.rpc.client.sender;
import io.datakernel.async.ResultCallback;
import io.datakernel.rpc.client.RpcClientConnectionPool;
import java.net.InetSocketAddress;
import java.util.*;
import static io.datakernel.util.Preconditions.checkArgument;
/**
 * Strategy that distributes requests among its sub-strategies at random, choosing each
 * sub-strategy with probability proportional to the weight it was registered with.
 *
 * <p>Sub-strategies are registered with {@link #add(int, RpcStrategy)}; the actual sender is
 * built by {@link #createSender(RpcClientConnectionPool)}.
 */
public class RpcStrategyRandomSampling implements RpcStrategy {
    /**
     * Seed used when {@link Random#nextLong()} happens to return zero. The xorshift generator
     * inside {@link RandomSamplingSender} maps zero to zero forever, so zero must never be a seed.
     */
    private static final long FALLBACK_SEED = 2347230858016798896L;

    private final Random random = new Random();
    private final Map<RpcStrategy, Integer> strategyToWeight = new HashMap<>();

    private RpcStrategyRandomSampling() {}

    /**
     * Creates an empty strategy with no sub-strategies registered.
     *
     * @return a new, empty instance
     */
    public static RpcStrategyRandomSampling create() {
        return new RpcStrategyRandomSampling();
    }

    /**
     * Registers a sub-strategy together with its relative weight.
     *
     * @param weight   non-negative relative weight; a sub-strategy with weight 0 is never chosen
     * @param strategy sub-strategy to register; must not have been registered already
     * @return this instance, to allow call chaining
     * @throws IllegalArgumentException presumably, via {@code checkArgument}, when the weight is
     *                                  negative or the strategy is already registered
     */
    public RpcStrategyRandomSampling add(int weight, RpcStrategy strategy) {
        checkArgument(weight >= 0, "weight cannot be negative");
        checkArgument(!strategyToWeight.containsKey(strategy), "withStrategy is already added");
        strategyToWeight.put(strategy, weight);
        return this;
    }

    /**
     * Collects the addresses of every registered sub-strategy, regardless of weight.
     *
     * @return a fresh set holding the union of all sub-strategy addresses
     */
    @Override
    public Set<InetSocketAddress> getAddresses() {
        Set<InetSocketAddress> result = new HashSet<>();
        for (RpcStrategy strategy : strategyToWeight.keySet()) {
            result.addAll(strategy.getAddresses());
        }
        return result;
    }

    /**
     * Builds a sender that picks one of the sub-strategies' senders at random for every request.
     *
     * <p>Sub-strategies whose {@code createSender} returns {@code null} are skipped.
     *
     * @param pool connection pool handed to every sub-strategy
     * @return the sampling sender, or {@code null} when the usable senders have a total weight of 0
     */
    @Override
    public RpcSender createSender(RpcClientConnectionPool pool) {
        // Parallel lists rather than a Map keyed by sender: two sub-strategies may yield senders
        // that compare equal (or the very same instance), and a map would silently drop one weight.
        List<RpcSender> senders = new ArrayList<>(strategyToWeight.size());
        List<Integer> weights = new ArrayList<>(strategyToWeight.size());
        // Summed as long so that large int weights cannot wrap around to a bogus total.
        long totalWeight = 0L;
        for (Map.Entry<RpcStrategy, Integer> entry : strategyToWeight.entrySet()) {
            RpcSender sender = entry.getKey().createSender(pool);
            if (sender != null) {
                int weight = entry.getValue();
                senders.add(sender);
                weights.add(weight);
                totalWeight += weight;
            }
        }
        if (totalWeight == 0L) {
            return null;
        }
        long randomLong = random.nextLong();
        long seed = randomLong != 0L ? randomLong : FALLBACK_SEED;
        return new RandomSamplingSender(senders, weights, seed);
    }

    /**
     * Sender that forwards each request to one of its delegates, chosen by weighted random
     * sampling over a cumulative-weight table.
     *
     * <p>The generator state is a plain field updated without any synchronization.
     */
    private static final class RandomSamplingSender implements RpcSender {
        private final List<RpcSender> senders;
        /** {@code cumulativeWeights[i]} is the sum of the weights of senders {@code 0..i}. */
        private final long[] cumulativeWeights;
        private final long totalWeight;
        private long lastRandomLong;

        /**
         * @param senders delegates to choose from; must not contain {@code null}
         * @param weights weight of the delegate at the same index in {@code senders}
         * @param seed    initial generator state; the caller is expected to pass a non-zero value
         */
        public RandomSamplingSender(List<RpcSender> senders, List<Integer> weights, long seed) {
            checkArgument(!senders.contains(null), "sender cannot be null");
            checkArgument(senders.size() == weights.size(), "senders and weights must have the same size");
            this.senders = new ArrayList<>(senders);
            this.cumulativeWeights = new long[weights.size()];
            long runningTotal = 0L;
            for (int i = 0; i < cumulativeWeights.length; i++) {
                runningTotal += weights.get(i);
                cumulativeWeights[i] = runningTotal;
            }
            this.totalWeight = runningTotal;
            this.lastRandomLong = seed;
        }

        /** Forwards the request to a randomly chosen delegate. */
        @Override
        public <I, O> void sendRequest(I request, int timeout, ResultCallback<O> callback) {
            chooseSender().sendRequest(request, timeout, callback);
        }

        /** Draws a value in {@code [0, totalWeight)} and maps it to the delegate owning that slot. */
        private RpcSender chooseSender() {
            // Masking with Long.MAX_VALUE clears the sign bit so the remainder is never negative.
            long currentRandomValue = (generateRandomLong() & Long.MAX_VALUE) % totalWeight;
            int senderIndex = binarySearch(currentRandomValue, 0, cumulativeWeights.length);
            return senders.get(senderIndex);
        }

        /**
         * Advances the generator by one xorshift step (shifts 21, 35, 4) and returns the new state.
         *
         * @return the next pseudo-random 64-bit value
         */
        public long generateRandomLong() {
            lastRandomLong ^= (lastRandomLong << 21);
            lastRandomLong ^= (lastRandomLong >>> 35);
            lastRandomLong ^= (lastRandomLong << 4);
            return lastRandomLong;
        }

        /**
         * Finds the first index in {@code [lowerIndex, upperIndex)} whose cumulative weight is
         * strictly greater than {@code value}; returns {@code upperIndex} if there is none.
         * Zero-weight entries share their predecessor's cumulative weight and so are never returned
         * for a value below the total.
         */
        private int binarySearch(long value, int lowerIndex, int upperIndex) {
            int low = lowerIndex;
            int high = upperIndex;
            while (low < high) {
                int middle = (low + high) >>> 1;
                if (value >= cumulativeWeights[middle]) {
                    low = middle + 1;
                } else {
                    high = middle;
                }
            }
            return low;
        }
    }
}