package com.spbsu.bernulli.naiveMixture; import com.spbsu.bernulli.EM; import com.spbsu.commons.math.vectors.Mx; import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx; import com.spbsu.commons.random.FastRandom; import java.util.Arrays; /** * Created by noxoomo on 09/03/15. * <p/> * Mixture of bernoulli coins * let's q = (q_0,…,q_k) be some distribution (\sum q_i = 1) on (\theta_1, …, \theta_k) * we observe sequence of bernoulli events: * 1) choose \mu_i ~ q, (\mu_i = \theta_j for some j) * 2) toss n_i times a "coin" with parameter \mu_i * <p/> * Task: estimate \mu_i for every i * Subtask: estimate q_i and \theta_i */ public class BernoulliMixtureEM extends EM<NaiveMixture> { public double[] estimate(boolean needFit) { if (needFit) { fit(); } final double p[] = new double[sums.length]; expectation(); for (int j = 0; j < sums.length; ++j) { double prob = 0; for (int i = 0; i < k; ++i) { prob += dummy.get(i, j) * theta[i]; } p[j] = prob; } return p; } final int[] sums; final int total; final int k; final Mx dummy; final double[] cache; final double[] q; final double[] logq; final double[] theta; final double[] logtheta; final double[] logntheta; final FastRandom rand; public BernoulliMixtureEM(int[] sums, int total, int k, FastRandom rand) { this.sums = sums; this.total = total; this.k = k; this.dummy = new VecBasedMx(k, sums.length); cache = new double[k]; q = new double[k]; logq = new double[k]; theta = new double[k]; logtheta = new double[k]; logntheta = new double[k]; this.rand = rand; init(); } private void init() { double totalWeight = 0; for (int i = 0; i < q.length; ++i) { q[i] = rand.nextDouble(); totalWeight += q[i]; } for (int i = 0; i < q.length; ++i) { q[i] /= totalWeight; logq[i] = Math.log(q[i]); } for (int i = 0; i < theta.length; ++i) { theta[i] = rand.nextDouble(); } for (int i = 0; i < theta.length; ++i) { logtheta[i] = Math.log(theta[i]); logntheta[i] = Math.log(1 - theta[i]); } Arrays.sort(theta); } @Override protected void expectation() { final double n = total; for (int j = 0; j < sums.length; ++j) { final double m = sums[j]; double denum = 0; final int length = (k / 4) * 4; for (int i = 0; i < length; i += 4) { final double llt0; final double llt1; final double llt2; final double llt3; if (m != 0) { llt0 = logtheta[i]; llt1 = logtheta[i + 1]; llt2 = logtheta[i + 2]; llt3 = logtheta[i + 3]; } else { llt0 = 0; llt1 = 0; llt2 = 0; llt3 = 0; } final double rlt0; final double rlt1; final double rlt2; final double rlt3; if (m != n) { rlt0 = logntheta[i]; rlt1 = logntheta[i + 1]; rlt2 = logntheta[i + 2]; rlt3 = logntheta[i + 3]; } else { rlt0 = 0; rlt1 = 0; rlt2 = 0; rlt3 = 0; } final double lq0 = logq[i]; final double lq1 = logq[i + 1]; final double lq2 = logq[i + 2]; final double lq3 = logq[i + 3]; final double tmp0 = m * llt0 + (n - m) * rlt0 + lq0; final double tmp1 = m * llt1 + (n - m) * rlt1 + lq1; final double tmp2 = m * llt2 + (n - m) * rlt2 + lq2; final double tmp3 = m * llt3 + (n - m) * rlt3 + lq3; cache[i] = Math.exp(tmp0); cache[i + 1] = Math.exp(tmp1); cache[i + 2] = Math.exp(tmp2); cache[i + 3] = Math.exp(tmp3); denum += (cache[i] + cache[i + 1] + cache[i + 2] + cache[i + 3]); } for (int i = length; i < k; ++i) { double tmp = m != 0 ? m * logtheta[i] : 0; tmp += (n - m) != 0 ? (n - m) * logntheta[i] : 0; tmp += logq[i]; cache[i] = Math.exp(tmp); denum += cache[i]; } for (int i = 0; i < k; ++i) { dummy.set(i, j, cache[i] != 0 && denum != 0 ? cache[i] / denum : 0); } // //test // { // double totalWeight = 0; // for (int i = 0; i < k; ++i) { // totalWeight += dummy.get(i, j); // } // if (Math.abs(totalWeight - 1.0) > 1e-6) { // System.err.println("Error: probs should sum to one"); // } // } } } @Override protected double likelihood() { double ll = 0; final double n = total; for (int j = 0; j < sums.length; ++j) { final double m = sums[j]; double p = 0; //hotspot generates sse instructions final int length = (k / 4) * 4; for (int i = 0; i < length; i += 4) { final double llt0; final double llt1; final double llt2; final double llt3; if (m != 0) { llt0 = logtheta[i]; llt1 = logtheta[i + 1]; llt2 = logtheta[i + 2]; llt3 = logtheta[i + 3]; } else { llt0 = 0; llt1 = 0; llt2 = 0; llt3 = 0; } final double rlt0; final double rlt1; final double rlt2; final double rlt3; if (m != n) { rlt0 = logntheta[i]; rlt1 = logntheta[i + 1]; rlt2 = logntheta[i + 2]; rlt3 = logntheta[i + 3]; } else { rlt0 = 0; rlt1 = 0; rlt2 = 0; rlt3 = 0; } final double lq0 = logq[i]; final double lq1 = logq[i + 1]; final double lq2 = logq[i + 2]; final double lq3 = logq[i + 3]; final double tmp0 = m * llt0 + (n - m) * rlt0 + lq0; final double tmp1 = m * llt1 + (n - m) * rlt1 + lq1; final double tmp2 = m * llt2 + (n - m) * rlt2 + lq2; final double tmp3 = m * llt3 + (n - m) * rlt3 + lq3; final double p0 = Math.exp(tmp0); final double p1 = Math.exp(tmp1); final double p2 = Math.exp(tmp2); final double p3 = Math.exp(tmp3); final double p02 = p0 + p2; final double p13 = p1 + p3; p += p02 + p13; } for (int i = length; i < k; ++i) { double tmp = m != 0 ? m * logtheta[i] : 0; tmp += (n - m) != 0 ? (n - m) * logntheta[i] : 0; tmp += logq[i]; p += Math.exp(tmp); } ll += Math.log(p); } return ll; } @Override protected void maximization() { for (int i = 0; i < k; ++i) { double M = 0; double N = 0; double p = 0; int length = (sums.length / 4) * 4; for (int j = 0; j < length; j += 4) { final double prob0 = dummy.get(i, j); final double prob1 = dummy.get(i, j + 1); final double prob2 = dummy.get(i, j + 2); final double prob3 = dummy.get(i, j + 3); final double s0 = sums[j]; final double s1 = sums[j + 1]; final double s2 = sums[j + 2]; final double s3 = sums[j + 3]; final double m0 = prob0 * s0 + prob2 * s2; final double n0 = prob0 * total + prob2 * total; final double p0 = prob0 + prob2; final double m1 = prob1 * s1 + prob3 * s3; final double n1 = prob1 * total + prob3 * total; final double p1 = prob1 + prob3; M += m0 + m1; N += n0 + n1; p += p0 + p1; } for (int j = length; j < sums.length; ++j) { final double prob = dummy.get(i, j); M += prob * sums[j]; N += prob * total; p += prob; } theta[i] = M / N; logtheta[i] = Math.log(theta[i]); logntheta[i] = Math.log(1-theta[i]); q[i] = p / sums.length; logq[i] = Math.log(q[i]); } //test // { // double totalWeight = 0; // for (int i = 0; i < q.length; ++i) { // totalWeight += q[i]; // } // if (Math.abs(totalWeight - 1.0) > 1e-6) { // System.err.println("Error: probs should sum to one"); // } // } } int count = 100; double oldLikelihood = Double.NEGATIVE_INFINITY; @Override protected boolean stop() { if (count % 10 == 0) { final double currentLL = likelihood(); if (oldLikelihood + 1e-5 >= currentLL) { return true; } oldLikelihood = currentLL; } return --count <= 0; } @Override public NaiveMixture model() { return new NaiveMixture(theta, total, rand); } }