package com.spbsu.bernulli.naiveMixture;
import com.spbsu.bernulli.Mixture;
import com.spbsu.bernulli.MixtureObservations;
import com.spbsu.bernulli.Multinomial;
import com.spbsu.commons.random.FastRandom;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
/**
 * A mixture of binomial components whose per-component success probabilities are drawn
 * with {@code rand.nextDouble()} at construction time.
 *
 * <p>Each sampled observation first picks a component index with a {@link Multinomial} built
 * from the mixture weights, then draws the number of successes out of {@code count} trials
 * from a binomial distribution with that component's success probability.
 *
 * <p>Created by noxoomo on 27/03/15.
 */
public class NaiveMixture extends Mixture {
    /** Generator backing the binomial draws; seeded once from the caller's {@code FastRandom}. */
    private final RandomGenerator apacheRand;
    /** Chooses a component index for each observation from the mixture weights. */
    private final Multinomial multinomial;
    /** Success probability of each component, one entry per mixture weight. */
    private final double[] means;
    /** Number of trials behind every observation. */
    private final int count;

    /**
     * Creates a mixture with the given component weights and random success probabilities.
     *
     * @param q     mixture weights, one per component (passed to {@code Mixture} and {@link Multinomial})
     * @param count number of trials per observation
     * @param rand  randomness source; consumed by the super constructor, then for the binomial
     *              seed, the multinomial, and finally the component success probabilities
     */
    public NaiveMixture(final double[] q, final int count, final FastRandom rand) {
        super(q, rand);
        this.apacheRand = new MersenneTwister(rand.nextLong());
        this.multinomial = new Multinomial(rand, q);
        this.means = randomMeans(q.length, rand);
        this.count = count;
    }

    /**
     * Creates a mixture with {@code k} components whose weights are set up by the
     * {@code Mixture(int, FastRandom)} constructor.
     *
     * @param k     number of components
     * @param count number of trials per observation
     * @param rand  randomness source, consumed in the same order as the weight-array constructor
     */
    public NaiveMixture(final int k, final int count, final FastRandom rand) {
        super(k, rand);
        this.apacheRand = new MersenneTwister(rand.nextLong());
        // q is the inherited weight field, presumably populated by super(k, rand) -- confirm in Mixture.
        this.multinomial = new Multinomial(rand, q);
        this.means = randomMeans(q.length, rand);
        this.count = count;
    }

    /** Draws {@code k} success probabilities, one {@code rand.nextDouble()} call each, in index order. */
    private static double[] randomMeans(final int k, final FastRandom rand) {
        final double[] result = new double[k];
        for (int i = 0; i < k; ++i) {
            result[i] = rand.nextDouble();
        }
        return result;
    }

    /**
     * Draws {@code n} observations from this mixture.
     *
     * @param n number of observations
     * @return for every observation: the chosen component, that component's success
     *         probability, and the sampled number of successes out of {@code count} trials
     */
    public MixtureObservations<NaiveMixture> sample(int n) {
        final int[] components = new int[n];
        final double[] observedMeans = new double[n];
        final int[] sums = new int[n];
        // One distribution per component, created on first use. BinomialDistribution.sample()
        // only draws from the generator given at construction (the shared apacheRand), so reusing
        // an instance yields the same sequence as allocating a fresh one per observation.
        final BinomialDistribution[] dists = new BinomialDistribution[means.length];
        for (int i = 0; i < n; ++i) {
            final int component = multinomial.next();
            components[i] = component;
            observedMeans[i] = means[component];
            if (dists[component] == null) {
                dists[component] = new BinomialDistribution(apacheRand, count, means[component]);
            }
            sums[i] = dists[component].sample();
        }
        return new MixtureObservations<>(this, components, observedMeans, sums, count);
    }
}