/**
*
*/
package org.cmg.ml.sam.sim;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.function.Function;
import org.apache.commons.math3.random.RandomGenerator;
/**
* @author loreti
*
*/
public class RandomGeneratorRegistry {
private static RandomGeneratorRegistry instance;
private RandomGenerator rg;
private HashMap<Thread, RandomGenerator> registry;
private RandomGeneratorRegistry() {
this.registry = new HashMap<>();
this.rg = new DefaultRandomGenerator();
}
public synchronized static RandomGeneratorRegistry getInstance() {
if (instance == null) {
instance = new RandomGeneratorRegistry();
}
return instance;
}
public synchronized void register( RandomGenerator rg ) {
registry.put(Thread.currentThread(), rg);
}
public synchronized void unregister() {
registry.remove(Thread.currentThread());
}
public synchronized RandomGenerator get( ) {
return retrieve( Thread.currentThread() );
}
private RandomGenerator retrieve(Thread currentThread) {
RandomGenerator rg = registry.get(currentThread);
if (rg == null) {
rg = this.rg;
}
return rg;
}
@SafeVarargs
public static <T> T uniform( T ... data ) {
RandomGenerator rg = getInstance().get();
return data[rg.nextInt(data.length)];
}
public static <T> T uniformSelect( Collection<T> collection ) {
if (collection.size()==0) {
System.out.println("IS EMPTY!!!");
return null;
}
int idx = 0;
if (collection.size()>1) {
RandomGenerator rg = getInstance().get();
idx = rg.nextInt(collection.size());
}
int counter = 0;
T last = null;
for (T t : collection) {
last = t;
if (counter == idx) {
return t;
} else {
counter++;
}
}
return last;
}
public static <T> T select( Collection<T> collection , Function<T,Double> weight ) {
if (collection.size()==0) {
System.out.println("IS EMPTY!!!");
return null;
}
double[] weightArray = new double[collection.size()];
double total = 0.0;
ArrayList<T> elements = new ArrayList<>();
int counter = 0;
for (T e : collection) {
Double w = weight.apply(e);
if (w == null) {
w = 0.0;
}
total += w;
weightArray[counter] = total;
elements.add(e);
counter++;
}
return select( elements , weightArray , total );
}
public static <T> T weightedSelect(T[] data, double[] weights) {
//double total = DoubleStream.of(weights).sum();
//Arrays.parallelPrefix(weights, Double::sum);
//return select(new ArrayList<T>(Arrays.asList(data)), weights, total);
double total = 0;
double[] weightsArray = new double[weights.length];
for (int i = 0; i < weights.length; i++) {
total += weights[i];
weightsArray[i] = total;
}
return select(new ArrayList<T>(Arrays.asList(data)), weightsArray, total);
}
private static <T> T select(ArrayList<T> elements, double[] weightArray, double total) {
if (total == 0) {
return null;
}
double val = total*rnd();
for (int i=0 ; i<weightArray.length; i++ ) {
if (val<weightArray[i]) {
return elements.get(i);
}
}
return null;
}
public static double rnd() {
RandomGenerator rg = getInstance().get();
return rg.nextDouble();
}
public static double normal(double mean, double sd) {
RandomGenerator rg = getInstance().get();
return rg.nextGaussian()*sd+mean;
}
}