package com.spbsu.bernulli.betaBinomialMixture;
import com.spbsu.bernulli.EM;
import com.spbsu.bernulli.caches.BetaCache;
import com.spbsu.bernulli.caches.Digamma1Cache;
import com.spbsu.bernulli.caches.DigammaCache;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx;
import com.spbsu.commons.random.FastRandom;
import java.util.Arrays;
import static java.lang.Double.isNaN;
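/**
 * EM fitting of a k-component beta-binomial mixture to a set of series,
 * where series i reports sums[i] successes out of n trials. The E-step
 * computes component responsibilities; the M-step re-estimates the weights
 * q in closed form and the shape parameters (alpha, beta) by gradient ascent.
 *
 * A minimal usage sketch (the no-argument FastRandom constructor is an
 * assumption here; any seeded variant works the same way):
 * <pre>
 *   int n = 100;                       // trials per series
 *   int[] sums = {37, 52, 48, 90, 11}; // successes per series
 *   BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(2, sums, n, new FastRandom());
 *   double[] posteriorMeans = em.estimate(true); // fit, then empirical Bayes estimates
 * </pre>
 */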
public class BetaBinomialMixtureEM extends EM<BetaBinomialMixture> {
final int k; // number of mixture components
final int[] sums; // sums[i] = observed successes in series i (out of n trials)
final int n; // trials per series
final Mx dummy; // responsibilities: dummy(i, j) = P(component j | series i)
final BetaBinomialMixture model;
final FastRandom random;
final SpecialFunctionCache[] funcs;
final double[] gradientCache; // packed gradient: (dalpha_j, dbeta_j) at (2j, 2j + 1)
final double[] newtonCache; // packed Hessians: 3 entries per component
public BetaBinomialMixtureEM(int k, final int[] sums, final int n, FastRandom random) {
this.k = k; //components count
this.sums = sums;
this.n = n;
this.dummy = new VecBasedMx(sums.length, k);
this.model = new BetaBinomialMixture(k, n, random);
this.random = random;
this.funcs = new SpecialFunctionCache[k];
for (int i = 0; i < k; ++i) {
this.funcs[i] = new SpecialFunctionCache(this.model.alphas[i], this.model.betas[i], n);
}
this.gradientCache = new double[this.model.alphas.length * 2];
// k Hessian blocks: 3 entries (d^2/dalpha^2, d^2/dalpha dbeta, d^2/dbeta^2) per component
this.newtonCache = new double[this.model.alphas.length * 3];
this.oldPoint = new double[this.model.alphas.length * 2];
}
private void updateCache() {
for (int i = 0; i < k; ++i) {
funcs[i].update(model.alphas[i], model.betas[i]);
}
}
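/**
 * E-step: for every series i with m = sums[i] successes, fills dummy(i, j)
 * with the responsibility of component j,
 * r_ij = q_j * BB(m | n, alpha_j, beta_j) / sum_l q_l * BB(m | n, alpha_l, beta_l).
 */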
@Override
protected void expectation() {
double[] probs = new double[k];
updateCache();
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
double denom = 0;
for (int j = 0; j < k; ++j) {
probs[j] = model.q[j] * funcs[j].calculate(m, n);
denom += probs[j];
}
for (int j = 0; j < k; ++j) {
dummy.set(i, j, probs[j] /= denom);
}
}
}
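// M-step optimizer settings; iterations and newtonStartStep belong to the
// disabled Newton variant below, the gradient settings to the active M-step.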
private final int iterations = 2;
private final double gradientStartStep = 0.01;
private final double newtonStartStep = 0.01;
private final int gradientIters = 20;
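/**
 * One damped Newton step on the expected log-likelihood. Per component, the
 * 2x2 Hessian in (alpha, beta) is accumulated from trigamma values of
 * log BB(m | n, alpha, beta):
 * d^2/dalpha^2 = psi'(alpha + m) - psi'(alpha) + psi'(alpha + beta) - psi'(alpha + beta + n),
 * d^2/dalpha dbeta = psi'(alpha + beta) - psi'(alpha + beta + n),
 * and symmetrically in beta with m replaced by n - m. The 2x2 system is then
 * solved in closed form (Cramer's rule). Returns true if a parameter had to
 * be clamped at the 0.001 lower bound.
 */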
private boolean newtonStep(double step) {
updateCache();
Arrays.fill(newtonCache, 0.0);
fillGradient();
final double[] psi1abdiff = new double[k];
final double[] psi1a = new double[k];
final double[] psi1b = new double[k];
for (int i = 0; i < k; ++i) {
psi1abdiff[i] = funcs[i].digamma1(Type.AlphaBeta, 0) - funcs[i].digamma1(Type.AlphaBeta, n);
psi1a[i] = funcs[i].digamma1(Type.Alpha, 0);
psi1b[i] = funcs[i].digamma1(Type.Beta, 0);
}
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
for (int j = 0; j < k; ++j) {
final double psi1am = funcs[j].digamma1(Type.Alpha, m);
final double psi1bm = funcs[j].digamma1(Type.Beta, n - m);
final int idx0 = 3 * j;
final int idx1 = 3 * j + 1;
final int idx2 = 3 * j + 2;
final double p = dummy.get(i, j);
final double val0 = p * (-psi1a[j] + psi1am + psi1abdiff[j]); // d^2/dalpha^2 term
final double val1 = p * (psi1abdiff[j]); // d^2/dalpha dbeta term
final double val2 = p * (-psi1b[j] + psi1bm + psi1abdiff[j]); // d^2/dbeta^2 term
newtonCache[idx0] += val0;
newtonCache[idx1] += val1;
newtonCache[idx2] += val2;
}
}
boolean status = false;
for (int i = 0; i < k; ++i) {
final double dalpha = gradientCache[2 * i];
final double dbeta = gradientCache[2 * i + 1];
// symmetric 2x2 Hessian for component i
final double a = newtonCache[3 * i]; // d^2/dalpha^2
final double b = newtonCache[3 * i + 1]; // d^2/dalpha dbeta
final double d = newtonCache[3 * i + 2]; // d^2/dbeta^2
final double det = a * d - b * b;
final double stepAlpha = (d * dalpha - b * dbeta) / det;
final double stepBeta = (a * dbeta - b * dalpha) / det;
model.alphas[i] -= step * stepAlpha;
model.betas[i] -= step * stepBeta;
if (model.betas[i] < 0.001) {
model.betas[i] = 0.001;
status = true;
}
if (model.alphas[i] < 0.001) {
model.alphas[i] = 0.001;
status = true;
}
}
return status;
}
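/**
 * Accumulates the gradient of the expected log-likelihood in gradientCache
 * (alpha_j at 2j, beta_j at 2j + 1), using
 * d/dalpha log BB(m | n, alpha, beta)
 *   = psi(alpha + m) - psi(alpha) + psi(alpha + beta) - psi(alpha + beta + n)
 * and the symmetric expression in beta with m replaced by n - m.
 */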
private void fillGradient() {
Arrays.fill(gradientCache, 0.0);
final double[] psiasum = new double[k];
final double[] psibsum = new double[k];
for (int i = 0; i < k; ++i) {
final double psiab = funcs[i].digamma(Type.AlphaBeta, 0);
final double psiabn = funcs[i].digamma(Type.AlphaBeta, n);
final double psia = funcs[i].digamma(Type.Alpha, 0);
final double psib = funcs[i].digamma(Type.Beta, 0);
psiasum[i] = -psia + psiab - psiabn;
psibsum[i] = -psib + psiab - psiabn;
}
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
for (int j = 0; j < k; ++j) {
final double alphaGrad = dummy.get(i, j) * (psiasum[j] + funcs[j].digamma(Type.Alpha, m));
final double betaGrad = dummy.get(i, j) * (psibsum[j] + funcs[j].digamma(Type.Beta, n - m));
gradientCache[2 * j] += alphaGrad;
gradientCache[2 * j + 1] += betaGrad;
}
}
}
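/**
 * One gradient-ascent step on all (alpha_j, beta_j). Returns true when the
 * step is unusable (NaN gradient) or a parameter had to be clamped at the
 * 0.001 lower bound.
 */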
private boolean gradientStep(double step) {
updateCache();
fillGradient();
for (int i = 0; i < gradientCache.length; ++i) {
if (isNaN(gradientCache[i])) {
return true;
}
}
boolean status = false;
for (int i = 0; i < k; ++i) {
model.alphas[i] += step * gradientCache[2 * i];
model.betas[i] += step * gradientCache[2 * i + 1];
if (model.betas[i] < 0.001) {
model.betas[i] = 0.001;
status = true;
}
if (model.alphas[i] < 0.001) {
model.alphas[i] = 0.001;
status = true;
}
}
return status;
}
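// A disabled M-step variant: gradient ascent with step halving on the likelihood.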
// @Override
// protected void maximization() {
// double probs[] = new double[k];
// for (int i = 0; i < sums.length; ++i) {
// for (int j = 0; j < k; ++j) {
// probs[j] += dummy.get(i, j);
// }
// }
// double total = 0;
// for (int i = 0; i < k; ++i) {
// total += probs[i];
// }
// for (int i = 0; i < k; ++i)
// model.q[i] = probs[i] / total;
//
// double step = startStep;
// double ll = likelihood();
// int iters = iterations;
// while (true) {
// for (int i = 0; i < iters; ++i)
// if (gradientStep(step))
// break;
// double gradientNorm = 0;
// for (int i = 0; i < gradientCache.length; ++i) {
// gradientNorm += sqr(gradientCache[i]);
// }
// if (gradientNorm / gradientCache.length < 1e-2)
// return;
//
// double currentLL = likelihood();
// if (currentLL + 0.01 >= ll || step < 1e-6) {
// return;
// }
// step *= 0.5;
// }
// }
boolean first = true;
final int maxIterations = 5;
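/**
 * M-step: mixture weights are the normalized responsibility sums,
 * q_j = sum_i r_ij / sum_i sum_l r_il, after which each (alpha_j, beta_j)
 * is improved by up to gradientIters gradient-ascent steps.
 */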
@Override
protected void maximization() {
double[] probs = new double[k];
for (int i = 0; i < sums.length; ++i) {
for (int j = 0; j < k; ++j) {
probs[j] += dummy.get(i, j);
}
}
double total = 0;
for (int i = 0; i < k; ++i) {
total += probs[i];
}
for (int i = 0; i < k; ++i)
model.q[i] = probs[i] / total;
double step = gradientStartStep;
for (int i = 0; i < gradientIters; ++i) {
if (gradientStep(step)) {
return;
}
}
// first = true;
// for (int i = 0; i < iterations; ++i) {
// if (newtonStep(newtonStartStep)) {
// return;
// }
// }
// shrinkage();
}
final int maxObservations = 200;
// Heuristic guard against singularities: caps the effective prior sample size alpha + beta.
private void shrinkage() {
for (int i = 0; i < model.alphas.length; ++i) {
final double alpha = model.alphas[i];
final double beta = model.betas[i];
final double observations = alpha + beta;
if (observations > maxObservations) {
final double mean = alpha / observations;
// rescale so alpha + beta ~= 0.9 * maxObservations while preserving the component mean
model.alphas[i] = maxObservations * mean * 0.9;
model.betas[i] = maxObservations * (1 - mean) * 0.9;
}
}
}
int count = 50;
double[] oldPoint;
double oldLikelihood = Double.NEGATIVE_INFINITY;
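/**
 * Stops when the log-likelihood improves by less than 0.1 between EM
 * iterations or when the iteration budget (count) is exhausted.
 */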
@Override
protected boolean stop() {
final double currentLL = likelihood();
if (oldLikelihood + 1e-1 >= currentLL) {
return true;
}
oldLikelihood = currentLL;
count--;
if (count < 0)
return true;
return false;
}
@Override
public BetaBinomialMixture model() {
return model;
}
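/** Observed-data log-likelihood: sum_i log sum_j q_j * BB(sums[i] | n, alpha_j, beta_j). */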
@Override
protected double likelihood() {
updateCache();
double ll = 0;
for (int i = 0; i < sums.length; ++i) {
double p = 0;
final int m = sums[i];
for (int j = 0; j < model.alphas.length; ++j) {
p += model.q[j] * funcs[j].calculate(m, n);
}
ll += Math.log(p);
}
return ll;
}
/**
 * Empirical Bayes estimates from the fitted mixture: for each series i the
 * posterior mean of its success probability,
 * E[p_i | m_i] = sum_j r_ij * (m_i + alpha_j) / (n + alpha_j + beta_j).
 * @param fit when true, runs EM before estimating.
 */
public double[] estimate(boolean fit) {
if (fit) {
fit();
}
expectation();
final double[] result = new double[sums.length];
for (int i = 0; i < sums.length; ++i) {
for (int j = 0; j < k; ++j)
result[i] += dummy.get(i, j) * (sums[i] + model.alphas[j]) / (n + model.betas[j] + model.alphas[j]);
}
return result;
}
public enum Type {
Alpha,
AlphaBeta,
Beta
}
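/**
 * Per-component cache of the special-function values the EM steps evaluate
 * repeatedly: the beta-binomial pmf plus digamma and trigamma at
 * alpha + offset, beta + offset and alpha + beta + offset for offsets 0..n.
 */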
public static class SpecialFunctionCache {
final DigammaCache[] digammaCaches = new DigammaCache[3];
final Digamma1Cache[] digamma1Caches = new Digamma1Cache[3];
final BetaCache betaCache;
public SpecialFunctionCache(double alpha, double beta, int n) {
betaCache = new BetaCache(alpha, beta, n);
digammaCaches[0] = new DigammaCache(alpha, n);
digammaCaches[1] = new DigammaCache(beta, n);
digammaCaches[2] = new DigammaCache(alpha + beta, n);
digamma1Caches[0] = new Digamma1Cache(alpha, n);
digamma1Caches[1] = new Digamma1Cache(beta, n);
digamma1Caches[2] = new Digamma1Cache(alpha + beta, n);
}
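/** Beta-binomial pmf BB(m | n, alpha, beta), served from the beta cache. */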
public double calculate(int m, int n) {
return betaCache.calculate(m, n);
}
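/** Digamma psi(x + offset), where x is alpha, beta or alpha + beta according to type. */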
public final double digamma(Type type, int offset) {
if (type == Type.Alpha) {
return digammaCaches[0].calculate(offset);
} else if (type == Type.Beta) {
return digammaCaches[1].calculate(offset);
} else {
return digammaCaches[2].calculate(offset);
}
}
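/** Trigamma psi'(x + offset), where x is alpha, beta or alpha + beta according to type. */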
public double digamma1(Type type, int offset) {
if (type == Type.Alpha) {
return digamma1Caches[0].calculate(offset);
} else if (type == Type.Beta) {
return digamma1Caches[1].calculate(offset);
} else {
return digamma1Caches[2].calculate(offset);
}
}
public final void update(double alpha, double beta) {
betaCache.update(alpha, beta);
digammaCaches[0].update(alpha);
digammaCaches[1].update(beta);
digammaCaches[2].update(alpha + beta);
digamma1Caches[0].update(alpha);
digamma1Caches[1].update(beta);
digamma1Caches[2].update(alpha + beta);
}
}
}