package com.spbsu.bernulli;
import static com.spbsu.commons.math.MathTools.sqr;
/**
 * Immutable holder for a sample drawn from a mixture model, together with helpers that
 * score estimators of the hidden parameters by their summed squared error.
 *
 * <p>The arrays passed to the constructor are stored by reference, not copied.
 *
 * @param <MixtureDistribution> type of the object stored in {@link #owner}. Note that this is a
 *     type parameter of this class, not a reference to a concrete class of the same name.
 */
public class MixtureObservations<MixtureDistribution> {
  /** The object these observations are attached to; stored as given and never read here. */
  public final MixtureDistribution owner;

  /** Per-observation component index (see {@link #thetas}); not read by any method of this class. */
  public final int[] components;

  /** Reference parameters: theta[i] ~ Beta(alpha[components[i]], beta[components[i]]). */
  public final double[] thetas;

  /** Observed counts: sums[i] = sum of Bernoulli(theta[i]) draws. */
  public final int[] sums;

  /**
   * Divisor that turns a count from {@link #sums} into a frequency
   * (presumably the number of Bernoulli trials behind each sum -- confirm with the producer).
   */
  public final int n;

  /**
   * Stores the given values without copying or validating them.
   *
   * @param owner      object the observations belong to
   * @param components per-observation component indices
   * @param thetas     reference parameter values
   * @param sums       observed counts
   * @param n          divisor applied to each count by the naive estimator
   */
  public MixtureObservations(MixtureDistribution owner, int[] components, double[] thetas, int[] sums, int n) {
    this.owner = owner;
    this.components = components;
    this.thetas = thetas;
    this.sums = sums;
    this.n = n;
  }

  /**
   * Returns the sum of squared differences between {@link #thetas} and the given estimate.
   *
   * <p>Exactly {@code estimator.length} leading entries are compared, so a shorter estimate is
   * scored against a prefix of {@code thetas}.
   *
   * @param estimator estimated parameter values, index-aligned with {@code thetas}
   * @return accumulated squared error; {@code 0} for an empty estimate
   * @throws ArrayIndexOutOfBoundsException if {@code estimator} is longer than {@code thetas}
   */
  public final double quality(final double[] estimator) {
    final int count = estimator.length;
    final int unrolledEnd = (count / 4) * 4;
    double sum = 0;
    // Blocks of four are combined pairwise; the order is kept so floating-point results stay stable.
    for (int i = 0; i < unrolledEnd; i += 4) {
      sum += pairedSquareSum(
          thetas[i] - estimator[i],
          thetas[i + 1] - estimator[i + 1],
          thetas[i + 2] - estimator[i + 2],
          thetas[i + 3] - estimator[i + 3]);
    }
    for (int i = unrolledEnd; i < count; ++i) {
      final double diff = thetas[i] - estimator[i];
      sum += diff * diff;
    }
    return sum;
  }

  /**
   * Returns the sum of squared differences between {@link #thetas} and the naive frequency
   * estimate {@code sums[i] / n}, without materialising the estimate array.
   *
   * <p>Exactly {@code thetas.length} entries are compared.
   *
   * @return accumulated squared error of the naive estimate
   * @throws ArrayIndexOutOfBoundsException if {@code sums} is shorter than {@code thetas}
   */
  public final double naiveQuality() {
    final int count = thetas.length;
    final int unrolledEnd = (count / 4) * 4;
    double sum = 0;
    // Same blocked accumulation order as quality(), so both methods round identically.
    for (int i = 0; i < unrolledEnd; i += 4) {
      sum += pairedSquareSum(
          thetas[i] - naiveEstimate(i),
          thetas[i + 1] - naiveEstimate(i + 1),
          thetas[i + 2] - naiveEstimate(i + 2),
          thetas[i + 3] - naiveEstimate(i + 3));
    }
    for (int i = unrolledEnd; i < count; ++i) {
      final double diff = thetas[i] - naiveEstimate(i);
      sum += diff * diff;
    }
    return sum;
  }

  /**
   * Builds the naive frequency estimate {@code sums[i] / n} for every observed count.
   *
   * <p>The result always has {@code sums.length} fully populated entries; it does not depend on
   * {@link #thetas}. With {@code n == 0} the entries follow IEEE-754 division
   * (NaN or an infinity) rather than raising an exception.
   *
   * @return a newly allocated array of estimates, one per element of {@code sums}
   */
  public double[] naive() {
    final double[] est = new double[sums.length];
    // Bound the loop by the array being filled: the estimate is a function of sums only.
    for (int i = 0; i < est.length; ++i) {
      est[i] = naiveEstimate(i);
    }
    return est;
  }

  /** Naive frequency estimate for observation {@code i}, computed in double precision. */
  private double naiveEstimate(final int i) {
    return sums[i] * 1.0 / n;
  }

  /** Squares four differences and adds them as (d0^2 + d2^2) + (d1^2 + d3^2). */
  private static double pairedSquareSum(final double d0, final double d1, final double d2, final double d3) {
    final double even = d0 * d0 + d2 * d2;
    final double odd = d1 * d1 + d3 * d3;
    return even + odd;
  }
}