import org.junit.Test;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.CoreMatchers.allOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
/**
 * Statistical tests for {@code GenerateNonuniformRandomNumbers.getRandom}.
 *
 * <p>Each case draws {@link #N} samples and checks that the observed count of every value lies
 * within a window around its expected count {@code N * p}.
 */
public class GenerateNonuniformRandomNumbersTest {
  /** Number of samples drawn per test case. */
  private static final int N = 10000;

  /**
   * Half-width of the acceptance window, in binomial standard deviations.
   *
   * <p>The count of a value with probability {@code p} over {@code N} independent draws has
   * standard deviation {@code sqrt(N * p * (1 - p))}. A fixed absolute window (e.g. +/-50) is
   * roughly 1 sigma at N = 10000 and makes the test fail most runs even for a correct generator.
   * A 5-sigma window keeps spurious failures extremely rare.
   */
  private static final double TOLERANCE_SIGMAS = 5.0;

  /** Uniform distribution over five values. */
  @Test
  public void nonUniformRandomNumberGeneration1() {
    test(Arrays.asList(1, 2, 3, 4, 5), Arrays.asList(.2, .2, .2, .2, .2));
  }

  /** Linearly increasing probabilities. */
  @Test
  public void nonUniformRandomNumberGeneration2() {
    test(Arrays.asList(1, 2, 3, 4), Arrays.asList(.1, .2, .3, .4));
  }

  /** Heavily skewed distribution: one dominant value plus four rare ones. */
  @Test
  public void nonUniformRandomNumberGeneration3() {
    test(Arrays.asList(1, 2, 3, 4, 5), Arrays.asList(.9, .025, .025, .025, .025));
  }

  /**
   * Samples {@code getRandom(values, probabilities)} {@link #N} times and asserts that each value's
   * count is within {@link #TOLERANCE_SIGMAS} standard deviations of its expected count.
   *
   * @param values the candidate values; {@code values.get(i)} has probability
   *     {@code probabilities.get(i)}
   * @param probabilities the probability of each value, index-aligned with {@code values}
   * @throws IllegalArgumentException if the two lists differ in length
   * @throws IllegalStateException if {@code values} contains duplicates (from {@code Collectors.toMap})
   */
  private void test(List<Integer> values, List<Double> probabilities) {
    if (values.size() != probabilities.size()) {
      throw new IllegalArgumentException(
          "values and probabilities differ in size: " + values.size() + " vs " + probabilities.size());
    }

    // One counter per candidate value. The map itself is never structurally modified during
    // sampling; only the AtomicIntegers are mutated, which is safe from the parallel workers.
    final Map<Integer, AtomicInteger> counts = new ConcurrentHashMap<>(
        values.stream().collect(Collectors.toMap(v -> v, v -> new AtomicInteger(0))));

    IntStream.range(0, N)
        .parallel()
        .forEach(i -> {
          Integer sample = GenerateNonuniformRandomNumbers.getRandom(values, probabilities);
          AtomicInteger counter = counts.get(sample);
          if (counter == null) {
            // Fail with a clear message instead of a bare NullPointerException.
            throw new AssertionError("getRandom returned a value not in the input list: " + sample);
          }
          counter.incrementAndGet();
        });

    // Assertions run sequentially so the first failing value is reported deterministically.
    for (int i = 0; i < values.size(); i++) {
      double p = probabilities.get(i);
      double expected = N * p;
      // The +1 keeps the exclusive bounds satisfiable when p is 0 or 1 (sigma == 0).
      double tolerance = TOLERANCE_SIGMAS * Math.sqrt(N * p * (1 - p)) + 1;
      // Convert to double: the matchers are Matcher<Double>; AtomicInteger is not Comparable.
      double actual = counts.get(values.get(i)).get();
      assertThat(
          "count for value " + values.get(i),
          actual,
          allOf(greaterThan(expected - tolerance), lessThan(expected + tolerance)));
    }
  }
}