package func;
import java.util.Arrays;
import shared.DataSet;
import shared.Instance;
import dist.*;
import dist.Distribution;
import dist.MixtureDistribution;
import dist.DiscreteDistribution;
import dist.MultivariateGaussian;
/**
 * An EM clusterer.
 * <p>
 * Fits a {@link MixtureDistribution} whose components are
 * {@link MultivariateGaussian}s. The mixture is seeded from a
 * {@link KMeansClusterer} partition of the data and then re-estimated
 * repeatedly until the average log-likelihood stops changing by more than
 * the tolerance, or the iteration limit is reached.
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class EMClusterer extends AbstractConditionalDistribution implements FunctionApproximater {
    /**
     * The default convergence tolerance
     */
    private static final double TOLERANCE = 1E-6;
    /**
     * The default maximum number of iterations
     */
    private static final int MAX_ITERATIONS = 1000;
    /**
     * The mixture distribution, null until {@link #estimate(DataSet)} has run
     */
    private MixtureDistribution mixture;
    /**
     * The number of clusters
     */
    private int k;
    /**
     * The convergence threshold on the change in average log-likelihood
     */
    private double tolerance;
    /**
     * The max iterations
     */
    private int maxIterations;
    /**
     * How many iterations the last call to estimate took
     */
    private int iterations;
    /**
     * Whether to print stuff
     */
    private boolean debug = false;

    /**
     * Make a new em clusterer
     * @param k the number of clusters, must be at least 1
     * @param tolerance estimation stops once the average log-likelihood
     * changes by less than this between two successive iterations
     * @param maxIterations the upper bound on the number of iterations;
     * at least one iteration is always performed
     * @throws IllegalArgumentException if k is less than 1
     */
    public EMClusterer(int k, double tolerance, int maxIterations) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1 but was " + k);
        }
        this.k = k;
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
    }

    /**
     * Make a new clusterer with 2 clusters and the default tolerance
     * and iteration limit
     */
    public EMClusterer() {
        this(2, TOLERANCE, MAX_ITERATIONS);
    }

    /**
     * Compute the normalized cluster membership probabilities of an instance.
     * <p>
     * NOTE(review): only the component densities are used here; the
     * mixture's component weights are not factored in - confirm that
     * this is intended.
     * @param instance the instance to cluster
     * @return a discrete distribution with one probability per component
     * @throws IllegalStateException if the clusterer has not been estimated
     */
    public Distribution distributionFor(Instance instance) {
        if (mixture == null) {
            throw new IllegalStateException(
                "EMClusterer has not been estimated; call estimate(DataSet) first");
        }
        // fetch the components once instead of on every loop iteration
        Distribution[] components = mixture.getComponents();
        // calculate the log probs
        double[] probs = new double[components.length];
        double maxLog = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probs.length; i++) {
            probs[i] = components[i].logp(instance);
            maxLog = Math.max(maxLog, probs[i]);
        }
        if (maxLog == Double.NEGATIVE_INFINITY) {
            // every component gave a log density of -infinity, subtracting
            // the max would produce NaN, so no component is preferred
            Arrays.fill(probs, 1.0 / probs.length);
            return new DiscreteDistribution(probs);
        }
        // turn into real probs, shifting by the max log value first so
        // that the largest term is exp(0) = 1 and the sum cannot underflow
        double sum = 0;
        for (int i = 0; i < probs.length; i++) {
            probs[i] = Math.exp(probs[i] - maxLog);
            sum += probs[i];
        }
        // normalize
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return new DiscreteDistribution(probs);
    }

    /**
     * Fit the mixture to the data set.
     * <p>
     * The mixture is initialized from a k-means clustering and then
     * re-estimated until the change in average log-likelihood drops below
     * the tolerance or the iteration limit is hit. The number of
     * iterations performed is available from {@link #getIterations()}.
     * @param set the data set to cluster
     * @throws IllegalArgumentException if the data set is empty
     * @see func.FunctionApproximater#estimate(shared.DataSet)
     */
    public void estimate(DataSet set) {
        if (set.size() == 0) {
            // the priors and the average log-likelihood both divide by
            // quantities that are zero for an empty set
            throw new IllegalArgumentException("Cannot estimate from an empty data set");
        }
        mixture = buildInitialMixture(set);
        // reestimate
        boolean done = false;
        double lastLogLikelihood = 0;
        iterations = 0;
        while (!done) {
            if (debug) {
                System.out.println("On iteration " + iterations);
                System.out.println(mixture);
            }
            mixture.estimate(set);
            double logLikelihood = averageLogLikelihood(set);
            // the first iteration has no previous value to compare against
            done = (iterations > 0 && Math.abs(logLikelihood - lastLogLikelihood) < tolerance)
                || (iterations + 1 >= maxIterations);
            lastLogLikelihood = logLikelihood;
            iterations++;
        }
    }

    /**
     * Build the starting mixture from a k-means partition of the data:
     * one gaussian estimated per cluster, weighted by the cluster's share
     * of the total instance weight.
     * @param set the data set
     * @return the initial mixture
     */
    private MixtureDistribution buildInitialMixture(DataSet set) {
        // kmeans initialization
        KMeansClusterer kmeans = new KMeansClusterer(k);
        kmeans.estimate(set);
        int size = set.size();
        double[] prior = new double[k];
        double weightSum = 0;
        int[] counts = new int[k];
        int[] classifications = new int[size];
        for (int i = 0; i < size; i++) {
            Instance instance = set.get(i);
            int cluster = kmeans.value(instance).getDiscrete();
            classifications[i] = cluster;
            counts[cluster]++;
            prior[cluster] += instance.getWeight();
            weightSum += instance.getWeight();
        }
        // create data sets for each of the classes
        Instance[][] instances = new Instance[k][];
        for (int i = 0; i < instances.length; i++) {
            instances[i] = new Instance[counts[i]];
        }
        // reuse counts as the next free slot in each cluster's array
        Arrays.fill(counts, 0);
        for (int i = 0; i < size; i++) {
            int cluster = classifications[i];
            instances[cluster][counts[cluster]] = set.get(i);
            counts[cluster]++;
        }
        MultivariateGaussian[] initial = new MultivariateGaussian[k];
        for (int i = 0; i < initial.length; i++) {
            initial[i] = new MultivariateGaussian();
            initial[i].setDebug(debug);
            initial[i].estimate(new DataSet(instances[i]));
            prior[i] /= weightSum;
        }
        return new MixtureDistribution(initial, prior);
    }

    /**
     * Compute the mean per-instance log probability of the data set
     * under the current mixture.
     * @param set the data set
     * @return the average log-likelihood
     */
    private double averageLogLikelihood(DataSet set) {
        double logLikelihood = 0;
        for (int j = 0; j < set.size(); j++) {
            logLikelihood += mixture.logp(set.get(j));
        }
        return logLikelihood / set.size();
    }

    /**
     * Get the most probable cluster for an instance
     * @param i the instance
     * @return the mode of the cluster distribution for the instance
     * @see func.FunctionApproximater#value(shared.Instance)
     */
    public Instance value(Instance i) {
        return distributionFor(i).mode();
    }

    /**
     * Get the number of iterations it took
     * @return the number
     */
    public int getIterations() {
        return iterations;
    }

    /**
     * Is debug mode on
     * @return true if it is
     */
    public boolean isDebug() {
        return debug;
    }

    /**
     * Set debug mode on or off
     * @param b the debug mode
     */
    public void setDebug(boolean b) {
        debug = b;
    }

    /**
     * Get the mixture
     * @return the mixture, or null if estimate has not been called yet
     */
    public MixtureDistribution getMixture() {
        return mixture;
    }

    /**
     * Describe the fitted mixture, or report that nothing has been
     * estimated yet rather than failing on the missing mixture
     * @see java.lang.Object#toString()
     */
    public String toString() {
        if (mixture == null) {
            return "EMClusterer (not yet estimated, k=" + k + ")";
        }
        return mixture.toString();
    }
}