import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Random;
/**
 * Fits one-dimensional Gaussian mixtures to a data file with expectation-maximization (EM).
 *
 * <p>{@link #main(String[])} reads whitespace-separated numbers from a file and, for each
 * component count from 1 to 5, runs EM with random restarts and prints the best run as
 * comma-separated text. Only the component means are estimated; every component uses the
 * fixed standard deviation {@link #SIGMA} and an equal mixing weight.
 */
public class h3 {

    /** Standard deviation assigned to every component; it is never re-estimated. */
    public final static double SIGMA = 1.0;

    /**
     * One Gaussian mixture component: a mean, a standard deviation and a mixing weight.
     */
    public static class Particle {
        /** Mean of the component. */
        public double mu;
        /** Standard deviation of the component. */
        public double sigma;
        /** Mixing weight applied to the component's density. */
        public double tau;

        /**
         * Creates a component. Note the argument order: sigma first, then mu, then tau.
         *
         * @param sigma standard deviation
         * @param mu    mean
         * @param tau   mixing weight
         */
        public Particle(double sigma, double mu, double tau) {
            this.mu = mu;
            this.sigma = sigma;
            this.tau = tau;
        }

        /**
         * Evaluates the normal density with this component's mean and standard deviation.
         *
         * @param xi the point at which to evaluate the density
         * @return exp(-((xi - mu) / sigma)^2 / 2) / sqrt(2 * pi * sigma^2)
         */
        public double probabilityOfXGivenMu(double xi) {
            return Math.exp(-1 * Math.pow((xi - mu) / sigma, 2) / 2)
                    / Math.sqrt(2 * Math.PI * sigma * sigma);
        }

        /**
         * Evaluates the density at {@code xi} scaled by the mixing weight.
         *
         * @param xi the point at which to evaluate the density
         * @return {@code tau * probabilityOfXGivenMu(xi)}
         */
        public double weightedProbabilityOfXGivenMu(double xi) {
            return tau * probabilityOfXGivenMu(xi);
        }
    }

    /**
     * A mixture of {@link Particle} components together with the data being fitted and the
     * responsibility matrix produced by the expectation step.
     */
    public static class Particles {
        /** Number of EM runs per {@link #runEM()} call (the first run plus random restarts). */
        private static final int MAX_ATTEMPTS = 100;
        /** Upper bound on EM iterations within a single run. */
        private static final int MAX_ITERATIONS = 200;
        /** A run stops once the log-likelihood changes by no more than this amount. */
        private static final double CONVERGENCE_TOLERANCE = 0.001;
        /** Maximum number of data points listed in the responsibility table. */
        private static final int MAX_REPORTED_POINTS = 25;

        /** The mixture components. */
        public Particle[] particles;
        /** The data points being fitted. */
        public double[] x;
        /** Responsibilities: {@code z[i][j]} is the share of data point i assigned to component j. */
        public double[][] z;
        /** Standard deviation used by {@link #getLogLikelihood()}; set to {@link h3#SIGMA}. */
        public double sigma;
        /** Mixing weight used by {@link #getLogLikelihood()}; set to 1 / numParticles. */
        public double tau;

        private final String headerSeparator = ",";
        private final String outputSeparator = ",";
        // Formatters are created once instead of once per printed row.
        private final DecimalFormat muFormat = new DecimalFormat("#.##");
        private final DecimalFormat responsibilityFormat = new DecimalFormat("0.00E0");
        private final Random rand;

        /**
         * Creates a mixture whose components all use {@link h3#SIGMA} and weight
         * 1 / numParticles.
         *
         * @param numParticles number of components
         * @param initMuValues initial means; the first {@code numParticles} entries are used
         * @param x            the data points (the array is stored, not copied)
         */
        public Particles(int numParticles, double[] initMuValues, double[] x) {
            this.rand = new Random();
            this.particles = new Particle[numParticles];
            this.z = new double[x.length][numParticles];
            this.x = x;
            this.sigma = SIGMA;
            this.tau = 1.0 / numParticles;
            for (int i = 0; i < numParticles; i++) {
                particles[i] = new Particle(sigma, initMuValues[i], tau);
            }
        }

        /**
         * Returns log(tau * density(xi)) for one component, computed without exponentiating
         * so that it stays finite when the density itself would underflow to zero.
         */
        private static double logWeightedDensity(Particle p, double xi) {
            double standardized = (xi - p.mu) / p.sigma;
            return Math.log(p.tau)
                    - 0.5 * standardized * standardized
                    - 0.5 * Math.log(2 * Math.PI * p.sigma * p.sigma);
        }

        /**
         * Expectation step: fills {@code z} so that each row holds the components' weighted
         * densities at that data point, normalized to sum to one.
         *
         * <p>The normalization is done relative to the largest log term of the row. Dividing
         * raw densities instead yields 0/0 = NaN as soon as a point is far enough from every
         * mean for all densities to underflow.
         */
        private void runExpectationPhase() {
            double[] logWeights = new double[particles.length];
            for (int i = 0; i < x.length; i++) {
                double maxLogWeight = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < particles.length; j++) {
                    logWeights[j] = logWeightedDensity(particles[j], x[i]);
                    maxLogWeight = Math.max(maxLogWeight, logWeights[j]);
                }
                double total = 0;
                for (int j = 0; j < particles.length; j++) {
                    // The largest term becomes exp(0) = 1, so total is at least 1.
                    z[i][j] = Math.exp(logWeights[j] - maxLogWeight);
                    total += z[i][j];
                }
                for (int j = 0; j < particles.length; j++) {
                    z[i][j] /= total;
                }
            }
        }

        /**
         * Maximization step: sets each component's mean to the responsibility-weighted
         * average of the data. A component whose responsibilities sum to zero keeps its
         * previous mean, because the average would otherwise be 0/0.
         */
        private void runMaximizationPhase() {
            for (int j = 0; j < particles.length; j++) {
                double weightedSum = 0;
                double weightTotal = 0;
                for (int i = 0; i < x.length; i++) {
                    weightedSum += z[i][j] * x[i];
                    weightTotal += z[i][j];
                }
                if (weightTotal > 0) {
                    particles[j].mu = weightedSum / weightTotal;
                }
            }
        }

        /**
         * Runs EM {@value #MAX_ATTEMPTS} times and prints the report of the run with the
         * highest final log-likelihood to standard output.
         *
         * <p>The first run starts from the means given to the constructor; every later run
         * starts from random means. Each run iterates until the log-likelihood changes by
         * no more than {@value #CONVERGENCE_TOLERANCE} or {@value #MAX_ITERATIONS}
         * iterations have been performed. The components are left in the state of the last
         * run, which is not necessarily the best one.
         */
        public void runEM() {
            double bestLogLikelihood = Double.NEGATIVE_INFINITY;
            String bestOutput = null;
            for (int attempt = 0; attempt < MAX_ATTEMPTS; attempt++) {
                // Restart before anything is reported so that row 0 shows this run's
                // starting means rather than the previous run's result.
                if (attempt != 0) {
                    generateRandomStartValues();
                }
                // Refresh the responsibilities so the iteration-0 log-likelihood and BIC
                // belong to the starting means instead of zeros or a previous run.
                runExpectationPhase();

                StringBuilder sb = new StringBuilder();
                int iteration = 0;
                sb.append(printHeaders()).append('\n');
                sb.append(printCurrentMuValues(iteration)).append('\n');

                double currLogLikelihood = getLogLikelihood();
                double prevLogLikelihood;
                do {
                    prevLogLikelihood = currLogLikelihood;
                    runExpectationPhase();
                    runMaximizationPhase();
                    iteration++;
                    sb.append(printCurrentMuValues(iteration)).append('\n');
                    currLogLikelihood = getLogLikelihood();
                } while (Math.abs(currLogLikelihood - prevLogLikelihood) > CONVERGENCE_TOLERANCE
                        && iteration < MAX_ITERATIONS);

                sb.append(printValuesAndProb());
                if (bestOutput == null || currLogLikelihood > bestLogLikelihood) {
                    bestLogLikelihood = currLogLikelihood;
                    bestOutput = sb.toString();
                }
            }
            System.out.println(bestOutput);
        }

        /**
         * Assigns every component a mean drawn uniformly from the range of the data.
         * The range is found by scanning the data, so the data need not be sorted.
         * Does nothing when there is no data.
         */
        private void generateRandomStartValues() {
            if (x.length == 0) {
                return;
            }
            double min = x[0];
            double max = x[0];
            for (double value : x) {
                min = Math.min(min, value);
                max = Math.max(max, value);
            }
            double span = max - min;
            for (int i = 0; i < particles.length; i++) {
                particles[i].mu = rand.nextDouble() * span + min;
            }
        }

        /**
         * Builds the responsibility table: a header line followed by one line per data
         * point (at most {@value #MAX_REPORTED_POINTS}) listing the point and its
         * responsibility under each component.
         */
        private String printValuesAndProb() {
            StringBuilder sb = new StringBuilder();
            sb.append("x_i");
            for (int i = 1; i <= particles.length; i++) {
                sb.append(headerSeparator);
                sb.append(String.format("P(mu%d | x_i)", i));
            }
            sb.append('\n');
            int rows = Math.min(x.length, MAX_REPORTED_POINTS);
            for (int i = 0; i < rows; i++) {
                sb.append(x[i]);
                for (int j = 0; j < particles.length; j++) {
                    sb.append(outputSeparator);
                    sb.append(responsibilityFormat.format(z[i][j]));
                }
                sb.append('\n');
            }
            return sb.toString();
        }

        /**
         * Builds one report row: the iteration number, every component mean, the current
         * log-likelihood and the current BIC.
         *
         * @param iteration the iteration number to print in the first column
         */
        private String printCurrentMuValues(int iteration) {
            StringBuilder sb = new StringBuilder();
            sb.append(iteration);
            for (Particle particle : particles) {
                sb.append(outputSeparator);
                sb.append(muFormat.format(particle.mu));
            }
            sb.append(outputSeparator);
            sb.append(muFormat.format(getLogLikelihood()));
            sb.append(outputSeparator);
            sb.append(muFormat.format(getBIC()));
            return sb.toString();
        }

        /** Builds the header row matching {@link #printCurrentMuValues(int)}. */
        private String printHeaders() {
            StringBuilder sb = new StringBuilder();
            sb.append("Iteration");
            for (int i = 1; i <= particles.length; i++) {
                sb.append(headerSeparator);
                sb.append("mu" + i);
            }
            sb.append(headerSeparator);
            sb.append("LogLik");
            sb.append(headerSeparator);
            sb.append("BIC");
            return sb.toString();
        }

        /**
         * Returns the mixture density at a point.
         *
         * @param xi the point to evaluate
         * @return the sum of every component's weighted density at {@code xi}
         */
        public double getTotalProbabilityForDataPoint(double xi) {
            double sum = 0;
            for (Particle particle : particles) {
                sum += particle.weightedProbabilityOfXGivenMu(xi);
            }
            return sum;
        }

        /**
         * Returns {@code 2 * getLogLikelihood() - k * ln(n)}, where k is the number of
         * components and n the number of data points. With this sign convention (the
         * negative of the common {@code k * ln(n) - 2 * ln(L)} form) larger values are better.
         */
        public double getBIC() {
            return 2 * getLogLikelihood() - (particles.length * Math.log(x.length));
        }

        /**
         * Returns the log-likelihood score used for reporting and convergence checks.
         *
         * <p>For every data point it adds {@code log(tau) - 0.5 * log(2 * pi * sigma^2)} and
         * subtracts, for every component j, {@code z[i][j] * (x[i] - mu_j)^2 / (2 * sigma^2)}.
         * It uses this object's shared {@code tau} and {@code sigma} and whatever
         * responsibilities are currently stored in {@code z}.
         */
        public double getLogLikelihood() {
            double logSum = 0;
            for (int i = 0; i < x.length; i++) {
                logSum += (Math.log(tau) - 0.5 * Math.log(2 * Math.PI * sigma * sigma));
                for (int j = 0; j < particles.length; j++) {
                    logSum = logSum
                            - (z[i][j] * Math.pow(x[i] - particles[j].mu, 2) / (2 * sigma * sigma));
                }
            }
            return logSum;
        }

        /** Returns {@code exp(getLogLikelihood())}. */
        public double getLikelihood() {
            return Math.exp(getLogLikelihood());
        }
    }

    /**
     * Reads the data file named by the first argument and fits mixtures of 1 to 5
     * components, printing the best EM run for each count.
     *
     * <p>Initial means are spaced evenly strictly between the smallest and largest data
     * value. If no file is given, or the file contains no values, a message is written to
     * standard error and the method returns without fitting anything.
     *
     * @param args {@code args[0]} is the path of the data file
     * @throws Exception if the file cannot be read or contains a token that is not a number
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println("Usage: java h3 <data-file>");
            return;
        }
        final int MAX_NUMPARTICLES = 5;
        double[] data = parseInputFileToValues(args[0]);
        if (data.length == 0) {
            System.err.println("No numeric values found in " + args[0]);
            return;
        }
        Arrays.sort(data);
        double min = data[0];
        double max = data[data.length - 1];
        double average = averageOfArray(data);
        System.out.println("Data Average: " + average);
        for (int n = 1; n <= MAX_NUMPARTICLES; n++) {
            // n starting means splitting [min, max] into n + 1 equal intervals.
            double delta = (max - min) / (n + 1);
            double[] initialMuValues = new double[n];
            for (int i = 0; i < n; i++) {
                initialMuValues[i] = min + (i + 1) * delta;
            }
            new Particles(n, initialMuValues, data).runEM();
        }
        System.out.println("Done");
    }

    /**
     * Returns the arithmetic mean of the array (NaN for an empty array).
     */
    private static double averageOfArray(double[] data) {
        double total = 0;
        for (double value : data) {
            total += value;
        }
        return total / data.length;
    }

    /**
     * Reads every whitespace-separated number from a text file.
     *
     * <p>Values may be separated by any run of spaces, tabs or line breaks; blank lines
     * are ignored. The reader is closed even when reading or parsing fails.
     *
     * @param filePath path of the file to read
     * @return the values in file order; an empty array if the file holds no tokens
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a token is not a valid number
     */
    private static double[] parseInputFileToValues(String filePath) throws IOException {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader r = new BufferedReader(new FileReader(filePath))) {
            String l;
            while ((l = r.readLine()) != null) {
                sb.append(l);
                sb.append(' ');
            }
        }
        String joined = sb.toString().trim();
        if (joined.isEmpty()) {
            return new double[0];
        }
        // Split on runs of whitespace so repeated separators do not produce empty tokens.
        String[] rawData = joined.split("\\s+");
        double[] data = new double[rawData.length];
        for (int i = 0; i < rawData.length; i++) {
            data[i] = Double.parseDouble(rawData[i]);
        }
        return data;
    }
}