package de.invesdwin.util.math.decimal.internal.randomize;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.commons.math3.random.RandomGenerator;
import de.invesdwin.util.bean.tuple.Pair;
import de.invesdwin.util.collections.Lists;
import de.invesdwin.util.math.decimal.ADecimal;
import de.invesdwin.util.math.decimal.IDecimalAggregate;
/**
 * Resamples the values of a parent aggregate by first picking a chunk of the values and then picking a value
 * uniformly from within that chunk.
 * <p>
 * The parent values are split into chunks via {@code Lists.splitIntoPackageCount(...)}. The chunk at 1-based
 * position {@code i} receives the weight {@code i}, so its selection probability is {@code i / (1 + 2 + ... + n)}.
 * Chunks at higher positions are therefore drawn more often than chunks at lower positions.
 * <p>
 * NOTE(review): this class assumes that {@code Lists.splitIntoPackageCount(...)} returns at least
 * {@code chunkCount} chunks and that none of them is empty; confirm against that helper. An empty chunk would make
 * {@code random.nextInt(0)} fail during iteration.
 *
 * @param <E>
 *            the decimal type being resampled
 */
@NotThreadSafe
public class WeightedChunksAscendingRandomizer<E extends ADecimal<E>> implements IDecimalRandomizer<E> {

    /** Number of values each iterator produced by {@link #randomize(RandomGenerator)} will return. */
    private final int sampleSize;
    /**
     * One entry per chunk: the cumulative selection probability up to and including that chunk, paired with the
     * chunk itself. Entries are stored in ascending threshold order.
     */
    private final Pair<Double, ? extends List<E>>[] threshold_chunk;

    /**
     * Splits the parent's values into chunks and precomputes the cumulative probability threshold of each chunk.
     *
     * @param parent
     *            the aggregate whose values are resampled
     * @param chunkCount
     *            the number of chunks to weight and draw from
     */
    @SuppressWarnings("unchecked")
    public WeightedChunksAscendingRandomizer(final IDecimalAggregate<E> parent, final int chunkCount) {
        this.sampleSize = parent.values().size();
        final List<? extends List<E>> sampleChunks = Lists.splitIntoPackageCount(parent.values(), chunkCount);
        // 1 + 2 + ... + chunkCount, computed in double arithmetic to avoid int overflow
        final double chunkWeightsSum = chunkCount * (chunkCount + 1D) / 2D;
        double chunkProbabilitiesSum = 0D;
        // generic array creation is not possible, hence the raw array and the unchecked suppression
        threshold_chunk = new Pair[chunkCount];
        for (int i = 1; i <= chunkCount; i++) {
            final double chunkWeight = i;
            final double chunkProbability = chunkWeight / chunkWeightsSum;
            chunkProbabilitiesSum += chunkProbability;
            final int chunkIndex = i - 1;
            // The accumulated sum can end up marginally below 1.0 due to floating point rounding, which would
            // leave a tiny range of random draws without a matching chunk; pin the last threshold to exactly 1.0.
            final double threshold = i == chunkCount ? 1D : chunkProbabilitiesSum;
            threshold_chunk[chunkIndex] = Pair.of(threshold, sampleChunks.get(chunkIndex));
        }
    }

    /**
     * Creates an iterator that returns {@code sampleSize} randomly drawn values. Values are drawn independently of
     * each other, so the same value may be returned more than once.
     *
     * @param random
     *            the source of randomness used for both the chunk selection and the selection inside the chunk
     * @return an iterator over the resampled values; its {@code next()} throws {@link NoSuchElementException}
     *         once all values have been returned
     */
    @Override
    public Iterator<E> randomize(final RandomGenerator random) {
        return new Iterator<E>() {
            private int resampleIdx = 0;

            @Override
            public boolean hasNext() {
                return resampleIdx < sampleSize;
            }

            @Override
            public E next() {
                if (!hasNext()) {
                    // honor the Iterator contract instead of producing values beyond the sample size
                    throw new NoSuchElementException("All " + sampleSize + " samples have already been returned");
                }
                final List<E> sampleChunk = getSampleChunk(random);
                final int sourceIdx = random.nextInt(sampleChunk.size());
                resampleIdx++;
                return sampleChunk.get(sourceIdx);
            }
        };
    }

    /**
     * Draws a random number and returns the first chunk whose cumulative threshold is not below it.
     *
     * @param random
     *            the source of randomness
     * @return the selected chunk
     * @throws IllegalStateException
     *             if no chunk matches, which can only happen when there are no chunks at all
     */
    private List<E> getSampleChunk(final RandomGenerator random) {
        final double chunkThreshold = random.nextDouble();
        for (int i = 0; i < threshold_chunk.length; i++) {
            final Pair<Double, ? extends List<E>> pair = threshold_chunk[i];
            final double threshold = pair.getFirst();
            if (chunkThreshold <= threshold) {
                return pair.getSecond();
            }
        }
        throw new IllegalStateException("No chunk found for threshold: " + chunkThreshold);
    }
}