package com.plectix.simulator.simulationclasses.probability; import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import com.plectix.simulator.simulator.ThreadLocalData; /** * This class select an item from a collection based on the item's weight. * The data structure is inspired from a Skip List (see {@link http://en.wikipedia.org/wiki/Skip_list}). * Though the current implementation differs. * * @author ecemis */ public class SkipListSelector<E extends WeightedItem> implements WeightedItemSelector<E> { private static final double P = 0.5; private double totalWeight = 0.0; /** currentLevel is -1 if there is no item stored */ private int currentLevel = -1; private final SkipListItem<E> head = new SkipListItem<E>(); private final SkipListItem<E> tail = new SkipListItem<E>(); private final Map<E, SkipListItem<E>> weightedItemToSkipListItemMap = new LinkedHashMap<E, SkipListItem<E>>(); public SkipListSelector() { super(); } public final E select() { if (currentLevel == -1) { return null; } return search(head, currentLevel, totalWeight * ThreadLocalData.getRandom().getDouble()).getWeightedItem(); } public final void updatedItems(Collection<E> changedWeightedItemList) { for (E weightedItem : changedWeightedItemList) { updatedItem(weightedItem); } } public final void updatedItem(E weightedItem) { final SkipListItem<E> skipListItem = weightedItemToSkipListItemMap.get(weightedItem); if (skipListItem == null) { if (weightedItem.getWeight() > 0.0) { // this item is new, we need to add weightedItemToSkipListItemMap.put(weightedItem, addWeightedItem(weightedItem)); } } else { // this item is old... what is the new weight? final double newWeight = weightedItem.getWeight(); if (newWeight <= 0.0) { deleteItem(skipListItem); // let's clean this item to make sure that we'll have NullPointerException if there are any bugs skipListItem.clear(); weightedItemToSkipListItemMap.remove(weightedItem); } else { updateWeight(skipListItem, newWeight); } } } private final SkipListItem<E> addWeightedItem(E weightedItem) { final double newWeight = weightedItem.getWeight(); totalWeight += newWeight; final SkipListItem<E> newItem = new SkipListItem<E>(); newItem.setWeightedItem(weightedItem); int level = getRandomLevel(); for (int i = 0; i <= level; i++) { if (i <= currentLevel) { final SkipListItem<E> oldLastItem = tail.getBackwardPointer(i); oldLastItem.setForwardPointer(i, newItem); tail.setBackwardPointer(i, newItem); newItem.addPointersAndSum(oldLastItem, tail, tail.getSum(i) + newWeight); tail.resetSum(i); } else { // creating a new level: head.addPointersAndSum(null, newItem, 0.0); tail.addPointersAndSum(newItem, null, 0.0); newItem.addPointersAndSum(head, tail, totalWeight); } } for (int i= level+1; i <= currentLevel ; i++) { tail.incrementSum(i, newWeight); } // update current level here, not above! if (level > currentLevel) { currentLevel = level; } return newItem; } public double getTotalWeight() { return totalWeight; } public Set<E> asSet() { return weightedItemToSkipListItemMap.keySet(); } private final void deleteItem(SkipListItem<E> skipListItem) { final double weightDiff = -skipListItem.getSum(0); totalWeight += weightDiff; int lastUpdatedLevel = skipListItem.getLevel() - 1; // System.err.println(dumpLevels() + " --> deleting an item with level " + lastUpdatedLevel); for (int i= 0; i <= lastUpdatedLevel; i++) { final SkipListItem<E> previousItemAtThatLevel = skipListItem.getBackwardPointer(i); final SkipListItem<E> nextItemAtThatLevel = skipListItem.getForwardPointer(i); previousItemAtThatLevel.setForwardPointer(i, nextItemAtThatLevel); nextItemAtThatLevel.setBackwardPointer(i, previousItemAtThatLevel); if (i != 0) { nextItemAtThatLevel.incrementSum(i, skipListItem.getSum(i) + weightDiff); } if (previousItemAtThatLevel == head && nextItemAtThatLevel == tail) { // we have to delete that level and up // System.err.println("========= DELETING ALL LEVELS AT OR ABOVE: " + i); head.removePointersAndSum(i); tail.removePointersAndSum(i); currentLevel = head.getLevel() - 1; // System.err.println(dumpLevels()); return; } } adjustWeightsForward(skipListItem, weightDiff, lastUpdatedLevel); } private final int getRandomLevel() { int ret = 0; while (ThreadLocalData.getRandom().getDouble() < P && ret <= currentLevel) { ret++; } return ret; } private final void updateWeight(SkipListItem<E> skipListItem, double newWeight) { final double weightDiff = newWeight - skipListItem.getSum(0); if (weightDiff == 0.0) { // there is no weight change! return; } totalWeight += weightDiff; int lastUpdatedLevel = skipListItem.getLevel() - 1; for (int i= 0; i <= lastUpdatedLevel; i++) { skipListItem.incrementSum(i, weightDiff); } adjustWeightsForward(skipListItem, weightDiff, lastUpdatedLevel); } private final void adjustWeightsForward(SkipListItem<E> skipListItem, final double weightDiff, int lastUpdatedLevel) { while (lastUpdatedLevel < currentLevel) { final SkipListItem<E> nextItemAtThatLevel = skipListItem.getForwardPointer(lastUpdatedLevel); for (int i= lastUpdatedLevel+1; i < nextItemAtThatLevel.getLevel(); i++) { nextItemAtThatLevel.incrementSum(i, weightDiff); } lastUpdatedLevel = nextItemAtThatLevel.getLevel()-1; skipListItem = nextItemAtThatLevel; } } private final SkipListItem<E> search(SkipListItem<E> skipListItem, int level, double randomValue) { // Originally I wrote this method with tail-recursion but then rewrote it with iteration to have it more efficient while (true) { if (randomValue == 0.0) { return skipListItem.getForwardPointer(0); } SkipListItem<E> nextItemAtThatLevel = skipListItem.getForwardPointer(level); while (nextItemAtThatLevel == tail) { if (level == 0) { // we can be here if there are some round-off errors!!! // let's return the last item then!!! return skipListItem; } level--; nextItemAtThatLevel = skipListItem.getForwardPointer(level); } while (nextItemAtThatLevel.getSum(level) > randomValue) { if (level == 0) { return nextItemAtThatLevel; } level--; nextItemAtThatLevel = skipListItem.getForwardPointer(level); } randomValue -= nextItemAtThatLevel.getSum(level); skipListItem = nextItemAtThatLevel; } } public final String levelsToString() { StringBuffer stringBuffer = new StringBuffer(); for (SkipListItem<E> skipListItem = head; skipListItem != null; skipListItem = skipListItem.getForwardPointer(0)) { if (skipListItem == head) { stringBuffer.append("H"); } stringBuffer.append(skipListItem.getLevel()); if (skipListItem == tail) { stringBuffer.append("T"); } stringBuffer.append("-"); } stringBuffer.append("(currentLevel=" + currentLevel + ") (totalWeight=" + totalWeight + ")"); return stringBuffer.toString(); } public final String weightsToString() { StringBuffer stringBuffer = new StringBuffer(); for (SkipListItem<E> skipListItem = head; skipListItem != null; skipListItem = skipListItem.getForwardPointer(0)) { if (skipListItem == head) { stringBuffer.append("H"); } stringBuffer.append(skipListItem.getSum(0)); if (skipListItem == tail) { stringBuffer.append("T"); } stringBuffer.append("-"); } stringBuffer.append("(totalWeight=" + totalWeight + ")"); return stringBuffer.toString(); } }