package edu.umd.cloud9.example.clustering;

import java.util.Random;
import java.util.Set;

import com.google.common.collect.Sets;

/**
 * Expectation-Maximization (EM) for estimating the parameters of a mixture of univariate Gaussian
 * distributions, together with helpers for building an initial mixture model.
 */
public class ExpectationMaximization {
  // Maximum number of iterations permitted.
  private static final int MAX_ITERATIONS = 30;

  // EM stops once the relative change in log likelihood is no longer above this value.
  private static final double LOG_LIKELIHOOD_THRESHOLD = 10e-10;

  /**
   * Initializes the mixture model with the points that are closest to the given means. Every
   * component receives the same weight; its first parameter is the value of the point nearest to
   * the corresponding mean and its second parameter is 1.
   *
   * @param points point set to pick the initial component centers from
   * @param means desired approximate component centers, one per mixture component
   * @return initial mixture model with {@code means.length} components
   * @throws IllegalArgumentException if {@code means} is non-empty and no point can be selected
   *         (empty point set, or no point at a comparable distance from a mean)
   */
  public static UnivariateGaussianMixtureModel initialize(Point[] points, double[] means) {
    if (means.length > 0 && points.length == 0) {
      throw new IllegalArgumentException(
          "Cannot select points nearest to " + means.length + " mean(s) from an empty point set");
    }

    UnivariateGaussianMixtureModel mm = new UnivariateGaussianMixtureModel(means.length);
    for (int i = 0; i < means.length; i++) {
      // Divide in double precision: casting 1 to float first would store a float-rounded weight.
      mm.weight[i] = 1.0 / means.length;
      mm.param[i] = newParam(nearestPoint(points, means[i]).value, 1);
    }
    return mm;
  }

  /**
   * Initializes the mixture model with random points. Every component receives the same weight;
   * its first parameter is the value of a distinct, randomly chosen point and its second
   * parameter is 1.
   *
   * @param points point set to sample the initial component centers from
   * @param n number of mixture components
   * @return initial mixture model with {@code n} components
   * @throws IllegalArgumentException if {@code n} is positive and larger than the number of points
   */
  public static UnivariateGaussianMixtureModel initialize(Point[] points, int n) {
    UnivariateGaussianMixtureModel mm = new UnivariateGaussianMixtureModel(n);
    Integer[] arr = sampleNUniquePoints(n, points.length);

    for (int i = 0; i < n; i++) {
      // Divide in double precision: casting 1 to float first would store a float-rounded weight.
      mm.weight[i] = 1.0 / n;
      mm.param[i] = newParam(points[arr[i]].value, 1);
    }
    return mm;
  }

  /**
   * Performs the Expectation-Maximization algorithm. The parameters estimated correspond to
   * univariate Gaussian distributions. The input model is cloned and left untouched. Progress
   * (iteration number and log likelihood) is printed to standard output. Iteration stops when the
   * relative change in log likelihood is no longer above the threshold or when the maximum number
   * of iterations is reached. An empty point set yields an unmodified clone of the input model.
   *
   * @param points point set
   * @param m initial mixture model
   * @return learned mixture model
   */
  public static UnivariateGaussianMixtureModel run(Point[] points,
      UnivariateGaussianMixtureModel m) {
    UnivariateGaussianMixtureModel mixtureModel = m.clone();

    int numPoints = points.length;
    int iterations = 0;
    // p[n][k]: responsibility of component k for point n.
    double[][] p = new double[numPoints][mixtureModel.size];

    // Initial log likelihood
    double logLikelihoodNew = logLikelihood(points, mixtureModel);
    double logLikelihoodOld;
    System.out.printf("Iteration %2d: LL = %12.6f\n", iterations, logLikelihoodNew);

    if (numPoints == 0) {
      // Nothing to learn from; running the M-step would only produce 0/0 weights.
      return mixtureModel;
    }

    do {
      logLikelihoodOld = logLikelihoodNew;

      expectationStep(points, mixtureModel, p);
      maximizationStep(points, mixtureModel, p);

      // Update of iterations and log likelihood value
      iterations++;
      logLikelihoodNew = logLikelihood(points, mixtureModel);
      System.out.printf("Iteration %2d: LL = %12.6f\n", iterations, logLikelihoodNew);
    } while (Math.abs((logLikelihoodNew - logLikelihoodOld) / logLikelihoodOld)
        > LOG_LIKELIHOOD_THRESHOLD && iterations < MAX_ITERATIONS);

    return mixtureModel;
  }

  /**
   * E-step: fills {@code p[n][k]} with the weighted density of point n under component k,
   * normalized so that each row sums to one.
   */
  private static void expectationStep(Point[] points, UnivariateGaussianMixtureModel model,
      double[][] p) {
    int numComponents = model.size;
    for (int n = 0; n < points.length; n++) {
      double sum = 0;
      for (int k = 0; k < numComponents; k++) {
        double tmp = model.weight[k]
            * UnivariateGaussianMixtureModel.densityOfGaussian(points[n], model.param[k]);
        p[n][k] = tmp;
        sum += tmp;
      }

      if (sum > 0) {
        for (int k = 0; k < numComponents; k++) {
          p[n][k] /= sum;
        }
      } else {
        // Every weighted density is zero (or not a number): dividing would poison the whole
        // model with NaN, so spread the point evenly over all components instead.
        for (int k = 0; k < numComponents; k++) {
          p[n][k] = 1.0 / numComponents;
        }
      }
    }
  }

  /**
   * M-step: re-estimates the two parameters and the weight of every component from the
   * responsibilities in {@code p}.
   */
  private static void maximizationStep(Point[] points, UnivariateGaussianMixtureModel model,
      double[][] p) {
    int numPoints = points.length;
    for (int k = 0; k < model.size; k++) {
      double sum = 0;
      double mu = 0;
      double sigma = 0;

      // Responsibility mass of the component and the responsibility-weighted sum of values
      for (int n = 0; n < numPoints; n++) {
        double w = p[n][k];
        sum += w;
        mu += points[n].value * w;
      }

      if (!(sum > 0)) {
        // No point is attributed to this component: keep its previous parameters rather than
        // dividing by zero, and give it no weight.
        model.weight[k] = 0;
        continue;
      }
      mu /= sum;

      // Second parameter: responsibility-weighted mean of the squared deviations from mu
      for (int n = 0; n < numPoints; n++) {
        double diff = points[n].value - mu;
        sigma += p[n][k] * diff * diff;
      }
      sigma /= sum;

      // Set new mu and sigma
      model.param[k] = newParam(mu, sigma);
      model.weight[k] = sum / numPoints;
    }
  }

  /**
   * Computes the log likelihood, i.e. the sum over all points of the logarithm of the mixture
   * density at that point.
   *
   * @param points set of points
   * @param f mixture model
   * @return log likelihood
   */
  private static double logLikelihood(Point[] points, UnivariateGaussianMixtureModel f) {
    double value = 0;
    for (Point point : points) {
      value += Math.log(f.density(point));
    }
    return value;
  }

  /**
   * Draws {@code n} distinct random integers from the range {@code [0, length)}.
   *
   * @param n number of distinct values to draw; a non-positive value yields an empty array
   * @param length exclusive upper bound of the values drawn
   * @return array holding the {@code n} distinct values, in the iteration order of the backing set
   * @throws IllegalArgumentException if {@code n} is positive and greater than {@code length},
   *         since that many distinct values do not exist
   */
  public static final Integer[] sampleNUniquePoints(int n, int length) {
    if (n > 0 && n > length) {
      // Without this check the rejection loop below could never collect n distinct values.
      throw new IllegalArgumentException(
          "Cannot sample " + n + " unique values from a range of size " + length);
    }

    Random rand = new Random();
    Set<Integer> set = Sets.newHashSet();
    while (set.size() < n) {
      // Set.add ignores duplicates, so no separate membership test is needed.
      set.add(rand.nextInt(length));
    }
    return set.toArray(new Integer[set.size()]);
  }

  /** Returns the point whose value is closest to {@code target} (first one wins on ties). */
  private static Point nearestPoint(Point[] points, double target) {
    Point nearest = null;
    double minD = Double.MAX_VALUE;
    for (Point candidate : points) {
      double d = Math.abs(candidate.value - target);
      if (d < minD) {
        nearest = candidate;
        minD = d;
      }
    }
    if (nearest == null) {
      throw new IllegalArgumentException("No point lies at a comparable distance from " + target);
    }
    return nearest;
  }

  /** Builds a two-element parameter vector holding {@code mu} and {@code sigma}, in that order. */
  private static PVector newParam(double mu, double sigma) {
    PVector param = new PVector(2);
    param.array[0] = mu;
    param.array[1] = sigma;
    return param;
  }
}