package com.spbsu.bernulli.betaBinomialMixture; import com.spbsu.bernulli.Mixture; import com.spbsu.bernulli.MixtureObservations; import com.spbsu.bernulli.Multinomial; import com.spbsu.commons.random.FastRandom; import com.spbsu.commons.util.ArrayTools; import org.apache.commons.math3.distribution.BetaDistribution; import org.apache.commons.math3.distribution.BinomialDistribution; import org.apache.commons.math3.random.MersenneTwister; import org.apache.commons.math3.random.RandomGenerator; public class BetaBinomialMixture extends Mixture { public final double alphas[]; public final double betas[]; final RandomGenerator apacheRandom; final BetaDistribution[] samplers; final Multinomial multinomialSampler; private int count; void setCount(int count) { this.count = count; } public BetaBinomialMixture(double[] alphas, double[] betas, double[] q, int count) { this(alphas, betas, q, count, new FastRandom()); } public BetaBinomialMixture(double[] alphas, double[] betas, double[] q, int count, FastRandom rand) { super(q, rand); this.alphas = alphas; this.betas = betas; ArrayTools.parallelSort(q, alphas, betas); this.apacheRandom = new MersenneTwister(random.nextInt()); this.samplers = new BetaDistribution[q.length]; for (int i = 0; i < samplers.length; ++i) { samplers[i] = new BetaDistribution(apacheRandom, alphas[i], betas[i], BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); } multinomialSampler = new Multinomial(random, q); this.count = count; } private int randomInitLimit = 100; public BetaBinomialMixture(int k, int count, FastRandom random) { super(k, random); double[] alphas = new double[q.length]; double[] betas = new double[q.length]; for (int i = 0; i < q.length; ++i) { alphas[i] = random.nextDouble()*randomInitLimit; betas[i] = random.nextDouble()*randomInitLimit; } this.alphas = alphas; this.betas = betas; this.apacheRandom = new MersenneTwister(random.nextInt()); this.samplers = new BetaDistribution[q.length]; for (int i = 0; i < samplers.length; ++i) { samplers[i] = new BetaDistribution(apacheRandom, alphas[i], betas[i], BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); } multinomialSampler = new Multinomial(random, q); this.count = count; } public int sample() { BinomialDistribution sampler = new BinomialDistribution(apacheRandom, count, samplers[multinomialSampler.next()].sample()); return sampler.sample(); } @Override public MixtureObservations<BetaBinomialMixture> sample(int n) { final int components[] = new int[n]; final double[] thetas = new double[n]; final int sums[] = new int[n]; for (int i = 0; i < n; ++i) { components[i] = multinomialSampler.next(); thetas[i] = samplers[components[i]].sample(); BinomialDistribution sampler = new BinomialDistribution(apacheRandom, count, thetas[i]); sums[i] = sampler.sample(); } return new MixtureObservations<>(this, components, thetas, sums, count); } @Override public String toString() { StringBuilder builder = new StringBuilder(); builder.append(alphas.length).append(": "); for (int i = 0; i < alphas.length; ++i) { builder.append("(").append(q[i]).append(",").append(alphas[i]).append(",").append(betas[i]) .append(",").append(alphas[i] / (alphas[i]+betas[i])).append(")"); } return builder.toString(); } }