import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Arrays;
import java.util.Locale;
import java.util.Random;

/**
 * Fits one-dimensional Gaussian mixtures to a whitespace-separated file of numbers using
 * expectation-maximization (EM).
 *
 * <p>Only the component means are estimated: every component uses the fixed standard deviation
 * {@link #SIGMA} and an equal mixing weight of {@code 1 / numParticles}.
 */
public class h3 {
    /** Standard deviation shared by every mixture component; it is never re-estimated. */
    public final static double SIGMA = 1.0;

    /** One mixture component: a normal distribution with mean, standard deviation and weight. */
    public static class Particle {
        public double mu;
        public double sigma;
        public double tau;

        /**
         * @param sigma standard deviation of the component
         * @param mu mean of the component
         * @param tau mixing weight of the component
         */
        public Particle(double sigma, double mu, double tau) {
            this.mu = mu;
            this.sigma = sigma;
            this.tau = tau;
        }

        /** Returns the normal density N(xi | mu, sigma^2). */
        public double probabilityOfXGivenMu(double xi) {
            return Math.exp(-1 * Math.pow((xi - mu) / sigma, 2) / 2)
                    / Math.sqrt(2 * Math.PI * sigma * sigma);
        }

        /** Returns log N(xi | mu, sigma^2); finite even where the density itself underflows. */
        public double logProbabilityOfXGivenMu(double xi) {
            double standardized = (xi - mu) / sigma;
            return -0.5 * standardized * standardized - 0.5 * Math.log(2 * Math.PI * sigma * sigma);
        }

        /** Returns tau * N(xi | mu, sigma^2). */
        public double weightedProbabilityOfXGivenMu(double xi) {
            return tau * probabilityOfXGivenMu(xi);
        }

        /** Returns log(tau * N(xi | mu, sigma^2)). */
        public double logWeightedProbabilityOfXGivenMu(double xi) {
            return Math.log(tau) + logProbabilityOfXGivenMu(xi);
        }
    }

    /** A mixture of {@link Particle}s together with the data it is fitted to. */
    public static class Particles {
        /** Number of EM runs: the supplied start values first, then random restarts. */
        private static final int MAX_ATTEMPTS = 100;
        /** Upper bound on EM iterations within one attempt. */
        private static final int MAX_ITERATIONS = 200;
        /** An attempt stops once the log-likelihood changes by no more than this. */
        private static final double CONVERGENCE_THRESHOLD = 0.001;

        public Particle[] particles;
        public double[] x;
        /** z[i][j] = responsibility of component j for data point x[i] (rows sum to 1). */
        public double[][] z;
        public double sigma;
        public double tau;
        private final String headerSeparator = ",";
        private final String outputSeparator = ",";
        private final Random rand = new Random();

        /**
         * @param numParticles number of mixture components; must be positive
         * @param initMuValues initial means; at least {@code numParticles} entries are read
         * @param x data points (need not be sorted); the array is stored, not copied
         * @throws IllegalArgumentException if {@code numParticles <= 0} or too few initial means
         */
        public Particles(int numParticles, double[] initMuValues, double[] x) {
            if (numParticles <= 0) {
                throw new IllegalArgumentException("numParticles must be positive: " + numParticles);
            }
            if (initMuValues.length < numParticles) {
                throw new IllegalArgumentException("need " + numParticles
                        + " initial mu values but got " + initMuValues.length);
            }
            this.particles = new Particle[numParticles];
            this.z = new double[x.length][numParticles];
            this.x = x;
            this.sigma = SIGMA;
            this.tau = 1.0 / numParticles;
            for (int i = 0; i < numParticles; i++) {
                particles[i] = new Particle(sigma, initMuValues[i], tau);
            }
        }

        /** E-step: recomputes responsibilities z from the current means, in log space. */
        private void runExpectationPhase() {
            for (int i = 0; i < x.length; i++) {
                double logTotal = getLogTotalProbabilityForDataPoint(x[i]);
                for (int j = 0; j < particles.length; j++) {
                    // exp(log w_ij - log sum_j w_ij) avoids 0/0 when every density underflows.
                    z[i][j] = Math.exp(particles[j].logWeightedProbabilityOfXGivenMu(x[i]) - logTotal);
                }
            }
        }

        /** M-step: sets each mean to the responsibility-weighted average of the data. */
        private void runMaximizationPhase() {
            for (int j = 0; j < particles.length; j++) {
                double weightedSum = 0;
                double totalResponsibility = 0;
                for (int i = 0; i < x.length; i++) {
                    weightedSum += z[i][j] * x[i];
                    totalResponsibility += z[i][j];
                }
                // A component that owns no data would get 0/0 = NaN; keep its previous mean.
                if (totalResponsibility > 0) {
                    particles[j].mu = weightedSum / totalResponsibility;
                }
            }
        }

        /**
         * Runs EM from the current means and then from random restarts, and prints the trace of
         * the attempt with the highest log-likelihood. On return, the means and responsibilities
         * are those of that best attempt.
         */
        public void runEM() {
            double bestLogLikelihood = Double.NEGATIVE_INFINITY;
            String bestOutput = null;
            double[] bestMu = null;
            double[][] bestZ = null;
            for (int attempt = 0; attempt < MAX_ATTEMPTS; attempt++) {
                // Restart before printing iteration 0 so that row shows this attempt's start.
                if (attempt != 0) {
                    generateRandomStartValues();
                }
                StringBuilder sb = new StringBuilder();
                sb.append(printHeaders()).append('\n');
                int iteration = 0;
                sb.append(printCurrentMuValues(iteration)).append('\n');

                double currLogLikelihood = getLogLikelihood();
                double prevLogLikelihood;
                // Stop when the log-likelihood changes by at most CONVERGENCE_THRESHOLD
                // or after MAX_ITERATIONS iterations.
                do {
                    prevLogLikelihood = currLogLikelihood;
                    runExpectationPhase();
                    runMaximizationPhase();
                    iteration++;
                    sb.append(printCurrentMuValues(iteration)).append('\n');
                    currLogLikelihood = getLogLikelihood();
                } while (Math.abs(currLogLikelihood - prevLogLikelihood) > CONVERGENCE_THRESHOLD
                        && iteration < MAX_ITERATIONS);
                sb.append(printValuesAndProb());

                // The first attempt is always accepted, even if its log-likelihood is NaN.
                if (bestOutput == null || currLogLikelihood > bestLogLikelihood) {
                    bestLogLikelihood = currLogLikelihood;
                    bestOutput = sb.toString();
                    bestMu = snapshotMu();
                    bestZ = snapshotZ();
                }
            }
            restore(bestMu, bestZ);
            System.out.println(bestOutput);
        }

        /** Copies the current means. */
        private double[] snapshotMu() {
            double[] mus = new double[particles.length];
            for (int j = 0; j < particles.length; j++) {
                mus[j] = particles[j].mu;
            }
            return mus;
        }

        /** Deep-copies the responsibility matrix. */
        private double[][] snapshotZ() {
            double[][] copy = new double[z.length][];
            for (int i = 0; i < z.length; i++) {
                copy[i] = z[i].clone();
            }
            return copy;
        }

        /** Writes saved means and responsibilities back into the existing objects/arrays. */
        private void restore(double[] mus, double[][] responsibilities) {
            for (int j = 0; j < particles.length; j++) {
                particles[j].mu = mus[j];
            }
            for (int i = 0; i < z.length; i++) {
                System.arraycopy(responsibilities[i], 0, z[i], 0, responsibilities[i].length);
            }
        }

        /** Draws each mean uniformly from [min(x), max(x)); leaves means unchanged if x is empty. */
        private void generateRandomStartValues() {
            if (x.length == 0) {
                return;
            }
            double min = x[0];
            double max = x[0];
            for (double value : x) {
                min = Math.min(min, value);
                max = Math.max(max, value);
            }
            double range = max - min;
            for (Particle particle : particles) {
                particle.mu = rand.nextDouble() * range + min;
            }
        }

        /** Creates a formatter whose output does not depend on the default locale. */
        private static DecimalFormat newFormat(String pattern) {
            // The output is comma-separated, so a locale with ',' as decimal mark would corrupt it.
            return new DecimalFormat(pattern, DecimalFormatSymbols.getInstance(Locale.ROOT));
        }

        /** Formats the responsibilities of (up to) the first 25 data points as a CSV table. */
        private String printValuesAndProb() {
            StringBuilder sb = new StringBuilder();
            DecimalFormat df = newFormat("0.00E0");
            sb.append("x_i");
            for (int i = 1; i <= particles.length; i++) {
                sb.append(headerSeparator);
                sb.append("P(mu").append(i).append(" | x_i)");
            }
            sb.append('\n');
            for (int i = 0; i < Math.min(x.length, 25); i++) {
                sb.append(x[i]);
                for (int j = 0; j < particles.length; j++) {
                    sb.append(outputSeparator);
                    sb.append(df.format(z[i][j]));
                }
                sb.append('\n');
            }
            return sb.toString();
        }

        /** Formats one trace row: iteration, every mean, log-likelihood and BIC. */
        private String printCurrentMuValues(int iteration) {
            DecimalFormat df = newFormat("#.##");
            StringBuilder sb = new StringBuilder();
            sb.append(iteration);
            for (Particle particle : particles) {
                sb.append(outputSeparator);
                sb.append(df.format(particle.mu));
            }
            sb.append(outputSeparator);
            sb.append(df.format(getLogLikelihood()));
            sb.append(outputSeparator);
            sb.append(df.format(getBIC()));
            return sb.toString();
        }

        /** Returns the header row matching {@link #printCurrentMuValues(int)}. */
        private String printHeaders() {
            StringBuilder sb = new StringBuilder();
            sb.append("Iteration");
            for (int i = 1; i <= particles.length; i++) {
                sb.append(headerSeparator);
                sb.append("mu").append(i);
            }
            sb.append(headerSeparator);
            sb.append("LogLik");
            sb.append(headerSeparator);
            sb.append("BIC");
            return sb.toString();
        }

        /** Returns the mixture density sum_j tau_j * N(xi | mu_j, sigma_j^2). May underflow to 0. */
        public double getTotalProbabilityForDataPoint(double xi) {
            double sum = 0;
            for (Particle particle : particles) {
                sum += particle.weightedProbabilityOfXGivenMu(xi);
            }
            return sum;
        }

        /** Returns the log of the mixture density at xi, computed with log-sum-exp. */
        public double getLogTotalProbabilityForDataPoint(double xi) {
            double max = Double.NEGATIVE_INFINITY;
            for (Particle particle : particles) {
                max = Math.max(max, particle.logWeightedProbabilityOfXGivenMu(xi));
            }
            if (Double.isInfinite(max)) {
                return max;
            }
            double sum = 0;
            for (Particle particle : particles) {
                sum += Math.exp(particle.logWeightedProbabilityOfXGivenMu(xi) - max);
            }
            return max + Math.log(sum);
        }

        /**
         * Returns 2 * logL - k * ln(n), where k counts only the means (sigma and the weights are
         * fixed) and n is the number of data points. With this sign convention, larger is better.
         */
        public double getBIC() {
            return 2 * getLogLikelihood() - (particles.length * Math.log(x.length));
        }

        /** Returns the observed-data log-likelihood sum_i log(sum_j tau_j * N(x_i | mu_j)). */
        public double getLogLikelihood() {
            double logSum = 0;
            for (double xi : x) {
                logSum += getLogTotalProbabilityForDataPoint(xi);
            }
            return logSum;
        }

        /** Returns exp(log-likelihood); underflows to 0 for all but tiny data sets. */
        public double getLikelihood() {
            return Math.exp(getLogLikelihood());
        }
    }

    /**
     * Reads the data file named by {@code args[0]} and fits mixtures with 1 to 5 components,
     * printing the best EM trace for each.
     *
     * @param args {@code args[0]} is the path of a whitespace-separated file of numbers
     * @throws IOException if the file cannot be read
     */
    public static void main(String[] args) throws Exception {
        final int maxNumParticles = 5;
        if (args.length < 1) {
            System.err.println("Usage: java h3 <data-file>");
            return;
        }
        double[] data = parseInputFileToValues(args[0]);
        if (data.length == 0) {
            System.err.println("No numeric values found in " + args[0]);
            return;
        }
        Arrays.sort(data);
        double min = data[0];
        double max = data[data.length - 1];
        double average = averageOfArray(data);
        System.out.println("Data Average: " + average);
        for (int n = 1; n <= maxNumParticles; n++) {
            // Start the n means evenly spaced strictly inside [min, max].
            double delta = (max - min) / (n + 1);
            double[] initialMuValues = new double[n];
            for (int i = 0; i < n; i++) {
                initialMuValues[i] = min + (i + 1) * delta;
            }
            new Particles(n, initialMuValues, data).runEM();
        }
        System.out.println("Done");
    }

    /** Returns the arithmetic mean of data (NaN for an empty array). */
    private static double averageOfArray(double[] data) {
        double average = 0;
        for (int i = 0; i < data.length; i++) {
            average += data[i];
        }
        return average / data.length;
    }

    /**
     * Parses every whitespace-separated token in the file as a double.
     *
     * @return the values in file order; empty if the file holds no tokens
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a token is not a number
     */
    private static double[] parseInputFileToValues(String filePath) throws IOException {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader r = new BufferedReader(new FileReader(filePath))) {
            String line;
            while ((line = r.readLine()) != null) {
                sb.append(line).append(' ');
            }
        }
        String content = sb.toString().trim();
        if (content.isEmpty()) {
            return new double[0];
        }
        // Split on any whitespace run so tabs or repeated spaces do not yield empty tokens.
        String[] rawData = content.split("\\s+");
        double[] data = new double[rawData.length];
        for (int i = 0; i < rawData.length; i++) {
            data[i] = Double.parseDouble(rawData[i]);
        }
        return data;
    }
}