package com.spbsu.bernulli.MCMCBernoulliMixture; import com.spbsu.commons.random.FastRandom; import com.spbsu.commons.util.ArrayTools; import gnu.trove.list.array.TIntArrayList; import java.util.Arrays; import static com.spbsu.commons.math.MathTools.sqr; //estimates some vector parameter of distibution by sampling with metropolis-hastings public class MCMCBernoulliEstimation { private final int k; //number of mixtures private final int n; //number of observations private final int[] sums; private final int[] stateSums; private final int[] componentsMap; private TIntArrayList[] componentsPoints; double[] param; private final FastRandom rand; private final BernoulliPrior prior; private Estimator estimator; private final double[] likelihoods; private final double[] logSizesCache; private final boolean[] isLogProbCached; private long accepted; private long rejected; public double acceptedRate() { return ((double) accepted) / (accepted + rejected); } private boolean burnIn = false; // take values only after burn in private final int window = 1000; // decorrelation, for better estimation. (we don't have infinite sample size) private final double logStepProb; public MCMCBernoulliEstimation(int k, int n, int sums[], BernoulliPrior prior, FastRandom random) { this.k = k; this.logStepProb = -Math.log(k) - Math.log(k - 1); ; this.n = n; this.rand = new FastRandom(random.nextLong()); this.prior = prior; this.stateSums = new int[k]; this.componentsMap = new int[sums.length]; this.likelihoods = new double[k]; this.sums = sums; this.estimator = new Estimator(sums.length); this.componentsPoints = new TIntArrayList[k]; this.param = new double[k]; for (int i = 0; i < k; ++i) this.componentsPoints[i] = new TIntArrayList(sums.length); this.randomState(); this.logSizesCache = new double[n * sums.length + 1]; this.isLogProbCached = new boolean[n * sums.length + 1]; } private void randomState() { Arrays.fill(stateSums, 0); for (int i = 0; i < k; ++i) componentsPoints[i].clear(); for (int i = 0; i < componentsMap.length; ++i) { final int comp = rand.nextByte(k); componentsMap[i] = comp; stateSums[comp] += sums[i]; componentsPoints[comp].add(i); } updateLikelihood(); } private int move(int from, int entry, int to) { final int point = componentsPoints[from].get(entry); this.stateSums[from] -= sums[point]; this.stateSums[to] += sums[point]; final int lastInd = componentsPoints[from].size() - 1; componentsPoints[from].set(entry, componentsPoints[from].get(lastInd)); componentsPoints[from].removeAt(lastInd); componentsPoints[to].add(point); componentsMap[point] = to; updateLikelihood(from); updateLikelihood(to); return componentsPoints[to].size() - 1; } private void updateLikelihood(int i) { cachedLL -= likelihoods[i]; likelihoods[i] = prior.likelihood(stateSums[i], componentsPoints[i].size() * n); cachedLL += likelihoods[i]; } private void updateLikelihood() { for (int i = 0; i < likelihoods.length; ++i) updateLikelihood(i); cachedLL = 0; for (int i = 0; i < stateSums.length; ++i) cachedLL += likelihoods[i]; } private double cachedLL = 0; final double likelihood() { return cachedLL; } final double getNewLL(final int from, final int entry, final int to) { double ll = likelihood(); ll -= likelihoods[from]; ll -= likelihoods[to]; final int point = componentsPoints[from].get(entry); final int sum = sums[point]; final int s0 = stateSums[from] - sum; final int s1 = stateSums[to] + sum; ll += prior.likelihood(s0, (componentsPoints[from].size() - 1) * n); ll += prior.likelihood(s1, (componentsPoints[to].size() + 1) * n); return ll; } final boolean next() { final double currentLL = likelihood(); final int moveFrom = rand.nextByte(k); if (componentsPoints[moveFrom].size() <= 1) return false; int moveTo = rand.nextByte(k - 1); if (moveTo >= moveFrom) ++moveTo; final int entry = rand.nextInt(componentsPoints[moveFrom].size()); final double prob = getProb(componentsPoints[moveFrom].size()); final double invProb = getProb(componentsPoints[moveTo].size() + 1); final double newLL = getNewLL(moveFrom, entry, moveTo); final double accProb = Math.exp(newLL + invProb - prob - currentLL); if (rand.nextDouble() < accProb) { move(moveFrom, entry, moveTo); return true; } return false; } private double getProb(int size) { if (isLogProbCached[size]) { return logSizesCache[size]; } else { logSizesCache[size] = logStepProb - Math.log(size); isLogProbCached[size] = true; return logSizesCache[size]; } } //it'll be inlined private void save() { addEB(); // add(); } public void run(int iters) { { int it = 0; double currentMeans[] = new double[componentsMap.length]; int burnIters = 100000; while (!burnIn) { for (int i = 0; i < burnIters; ++i, ++it) { if (!next()) { ++rejected; } else { ++accepted; } if (i % window == 0) save(); } double[] means = estimation(); if (burned(means, currentMeans)) { burnIn = true; } burnIters *= 2; estimator.clear(); currentMeans = means; System.out.println("Accepted rate after " + it + " iters is " + acceptedRate()); } } for (int i = 0; i < iters; ++i) { if (!next()) { ++rejected; } else { ++accepted; } if (i % window == 0) save(); } System.out.println("Accepted rate " + acceptedRate()); } private boolean burned(double[] means, double[] currentMeans) { return dist(means, currentMeans) < 1; } private double dist(double[] first, double[] second) { final int len = (first.length / 4) * 4; double sum = 0; for (int i = 0; i < len; i += 4) { double diff0 = first[i] - second[i]; double diff1 = first[i + 1] - second[i + 1]; double diff2 = first[i + 2] - second[i + 2]; double diff3 = first[i + 3] - second[i + 3]; diff0 *= diff0; diff1 *= diff1; diff2 *= diff2; diff3 *= diff3; diff0 += diff2; diff1 += diff3; sum += diff0 + diff1; } for (int i = len; i < first.length; ++i) sum += sqr(first[i] - second[i]); return sum; } public final double[] estimation() { return estimator.get(); } public void clear() { estimator.clear(); } private void parameterEstimation() { for (int i = 0; i < k; ++i) { param[i] = stateSums[i] * 1.0 / n / componentsPoints[i].size(); } } final void add() { estimator.count++; parameterEstimation(); final int len = (componentsMap.length / 4) * 4; for (int i = 0; i < len; i += 4) { final int ind0 = componentsMap[i]; final int ind1 = componentsMap[i + 1]; final int ind2 = componentsMap[i + 2]; final int ind3 = componentsMap[i + 3]; estimator.meanSums[i] += param[ind0]; estimator.meanSums[i + 1] += param[ind1]; estimator.meanSums[i + 2] += param[ind2]; estimator.meanSums[i + 3] += param[ind3]; } for (int i = len; i < componentsMap.length; ++i) { final int ind0 = componentsMap[i]; estimator.meanSums[i] += param[ind0]; } } final void addEB() { //use model only for estimation of "true" parameters only estimator.count++; parameterEstimation(); final double[] logtheta = new double[k]; final double[] logntheta = new double[k]; for (int i = 0; i < param.length; ++i) { logtheta[i] = Math.log(param[i]); logntheta[i] = Math.log(1 - param[i]); } final double[] prior = new double[k]; final double[] posterior = new double[k]; for (int i = 0; i < prior.length; ++i) { prior[i] = Math.log(componentsPoints[i].size() * 1.0 / sums.length); } for (int i = 0; i < sums.length; ++i) { final double m = sums[i]; double denum = 0; if (m == 0) { for (int j = 0; j < k; ++j) { posterior[j] = Math.exp((n - m) * logntheta[j] + prior[j]); denum += posterior[j]; } } else if (m == n) { for (int j = 0; j < k; ++j) { posterior[j] = Math.exp(m * logtheta[j] + prior[j]); denum += posterior[j]; } } else { for (int j = 0; j < k; ++j) { posterior[j] = Math.exp(m * logtheta[j] + (n - m) * logntheta[j] + prior[j]); denum += posterior[j]; } } double est = 0; for (int j = 0; j < k; ++j) { est += posterior[j] * param[j] / denum; } estimator.meanSums[i] += est; } } private class Estimator { final double[] meanSums; private int count; //normalization for estimator public Estimator(int len) { this.meanSums = new double[len]; } public final double[] get() { double[] result = new double[meanSums.length]; System.arraycopy(meanSums, 0, result, 0, result.length); for (int i = 0; i < meanSums.length; ++i) result[i] /= count; return result; } public final void clear() { ArrayTools.fill(meanSums, 0); count = 0; } } }