/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.bookkeeper.client;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.apache.bookkeeper.client.BookieInfoReader.BookieInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Selects elements at random, each with a probability proportional to its weight.
 *
 * <p>{@link #updateMap(Map)} turns a map of weights into a table of cumulative probabilities and
 * publishes it under the write lock. {@link #getNextRandom()} draws from the most recently
 * published table under the read lock.
 *
 * <p>Entries whose weight is not positive are given the probability of the smallest non-zero
 * weight, so they remain selectable. When every weight is zero, all entries are equally likely.
 * When the max probability multiplier is positive, an entry's probability is capped at
 * {@code multiplier * median probability}. No cap is applied if that product is not positive.
 *
 * @param <T> type of the elements being selected
 */
public class WeightedRandomSelection<T> {
    static final Logger LOG = LoggerFactory.getLogger(WeightedRandomSelection.class);

    /** Orders weighted objects by ascending weight, without the overflow risk of subtracting. */
    private static final Comparator<WeightedObject> BY_WEIGHT = new Comparator<WeightedObject>() {
        @Override
        public int compare(WeightedObject o1, WeightedObject o2) {
            return Long.compare(o1.getWeight(), o2.getWeight());
        }
    };

    /** An object that carries a selection weight. */
    interface WeightedObject {
        /** @return this object's weight; larger weights are picked more often */
        long getWeight();
    }

    /**
     * Sum of all published probabilities. Random draws fall in {@code [0, randomMax)}.
     * It is null until {@link #updateMap(Map)} is first called.
     */
    Double randomMax;
    /** Probability cap as a multiple of the median probability; non-positive disables the cap. */
    int maxProbabilityMultiplier;
    /** The map passed to the most recent {@link #updateMap(Map)} call (stored by reference). */
    Map<T, WeightedObject> map;
    /** Maps the start of each entry's cumulative-probability interval to that entry. */
    TreeMap<Double, T> cummulativeMap = new TreeMap<Double, T>();
    /** Guards {@link #map}, {@link #cummulativeMap} and {@link #randomMax}. */
    ReadWriteLock rwLock = new ReentrantReadWriteLock(true);

    /** Creates a selection with probability capping disabled. */
    WeightedRandomSelection() {
        maxProbabilityMultiplier = -1;
    }

    /**
     * Creates a selection whose per-entry probability is capped.
     *
     * @param maxMultiplier cap, as a multiple of the median probability; non-positive disables it
     */
    WeightedRandomSelection(int maxMultiplier) {
        this.maxProbabilityMultiplier = maxMultiplier;
    }

    /**
     * Sets the probability cap used by subsequent {@link #updateMap(Map)} calls.
     *
     * @param max cap, as a multiple of the median probability; non-positive disables it
     */
    public void setMaxProbabilityMultiplier(int max) {
        this.maxProbabilityMultiplier = max;
    }

    /**
     * Replaces the selection table with one computed from {@code map}.
     *
     * <p>The table is built before the write lock is taken, so readers are blocked only for the
     * final swap. An empty map publishes an empty table, after which {@link #getNextRandom()}
     * returns null.
     *
     * @param map elements and their weights; stored by reference
     */
    void updateMap(Map<T, WeightedObject> map) {
        TreeMap<Double, T> tmpCummulativeMap = new TreeMap<Double, T>();
        double cumulative = 0.0;
        // Skip the statistics for an empty map: they would divide by a zero total.
        if (!map.isEmpty()) {
            Map<T, Double> weightMap = computeProbabilities(map);
            // Each entry owns the interval [start, start + probability) and is keyed by its start,
            // so floorEntry() in getNextRandom() finds the owner of a random draw.
            for (Map.Entry<T, Double> e : weightMap.entrySet()) {
                tmpCummulativeMap.put(cumulative, e.getKey());
                LOG.debug("Key: {} Value: {} AssignedKey: {} AssignedWeight: {}",
                        new Object[] {e.getKey(), e.getValue(), cumulative, e.getValue()});
                cumulative += e.getValue();
            }
        }
        rwLock.writeLock().lock();
        try {
            this.map = map;
            cummulativeMap = tmpCummulativeMap;
            randomMax = cumulative;
        } finally {
            rwLock.writeLock().unlock();
        }
    }

    /**
     * Converts raw weights into per-entry probabilities, applying the optional median-based cap.
     * The map must be non-empty.
     */
    private Map<T, Double> computeProbabilities(Map<T, WeightedObject> map) {
        List<WeightedObject> values = new ArrayList<WeightedObject>(map.values());
        long totalWeight = 0L;
        long min = Long.MAX_VALUE;
        for (WeightedObject value : values) {
            long weight = value.getWeight();
            totalWeight += weight;
            if (weight != 0 && weight < min) {
                min = weight;
            }
        }
        double median;
        if (totalWeight == 0) {
            // All weights are zero: treat every entry as weight 1, so the selection is uniform.
            min = 1L;
            median = 1;
            totalWeight = values.size();
        } else {
            median = median(values);
        }
        double medianWeight = median / totalWeight;
        double minWeight = (double) min / totalWeight;
        LOG.debug("Updating weights map. MediaWeight: {} MinWeight: {}", medianWeight, minWeight);
        double maxWeight = maxProbabilityMultiplier * medianWeight;

        Map<T, Double> weightMap = new HashMap<T, Double>();
        for (Map.Entry<T, WeightedObject> e : map.entrySet()) {
            long weight = e.getValue().getWeight();
            // Non-positive weights get the smallest non-zero share so they stay selectable.
            double weightedProbability = weight > 0 ? (double) weight / totalWeight : minWeight;
            if (maxWeight > 0 && weightedProbability > maxWeight) {
                weightedProbability = maxWeight;
                LOG.debug("Capping the probability to {} for {} Value: {}",
                        new Object[] {weightedProbability, e.getKey(), e.getValue()});
            }
            weightMap.put(e.getKey(), weightedProbability);
        }
        return weightMap;
    }

    /** Returns the median weight of a non-empty list, sorting the list in place. */
    private static double median(List<WeightedObject> values) {
        Collections.sort(values, BY_WEIGHT);
        int mid = values.size() / 2;
        if (values.size() % 2 == 1) {
            return values.get(mid).getWeight();
        }
        // Add as doubles so two large weights cannot overflow a long.
        return ((double) values.get(mid - 1).getWeight() + values.get(mid).getWeight()) / 2;
    }

    /**
     * Picks an element at random, weighted by the most recently published probabilities.
     *
     * @return the selected element, or null if {@link #updateMap(Map)} has not been called yet or
     *     was last called with an empty map
     */
    T getNextRandom() {
        rwLock.readLock().lock();
        try {
            if (randomMax == null || cummulativeMap.isEmpty()) {
                return null;
            }
            // Pick a point in [0, randomMax); the entry whose interval contains it is the winner.
            double randomNum = randomMax * ThreadLocalRandom.current().nextDouble();
            Map.Entry<Double, T> entry = cummulativeMap.floorEntry(randomNum);
            if (entry == null) {
                // Defensive: only a negative randomMax (e.g. from negative weights) lands here.
                entry = cummulativeMap.firstEntry();
            }
            return entry.getValue();
        } finally {
            rwLock.readLock().unlock();
        }
    }
}