package shared.filt;
import dist.MultivariateGaussian;
import shared.DataSet;
import shared.Instance;
import util.linalg.Matrix;
import util.linalg.RectangularMatrix;
import util.linalg.SymmetricEigenvalueDecomposition;
import util.linalg.Vector;
/**
 * A reversible filter that performs principal component analysis (PCA)
 * on a set of data. The components are estimated once, at construction
 * time, from the covariance matrix of the supplied data set; the filter
 * can then project data onto the retained components and map projected
 * data back into the original space.
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class PrincipalComponentAnalysis implements ReversibleFilter {
    /**
     * The default eigenvalue threshold
     */
    private static final double THRESHOLD = 1E-6;

    /**
     * The projection matrix; row i holds the i-th retained eigenvector
     */
    private final Matrix projection;

    /**
     * The eigenvalue matrix produced by the decomposition
     */
    private final Matrix eigenValues;

    /**
     * The mean vector estimated from the training data
     */
    private final Vector mean;

    /**
     * Make a new PCA filter
     * @param dataSet the set from which to estimate components
     * @param toKeep the maximum number of components to keep, or -1 to
     * consider all of them; values larger than the dimension of the data
     * are clamped to that dimension
     * @param threshold components are kept only while their eigenvalue is
     * strictly greater than this value
     * @throws IllegalArgumentException if toKeep is less than -1
     */
    public PrincipalComponentAnalysis(DataSet dataSet, int toKeep, double threshold) {
        if (toKeep < -1) {
            throw new IllegalArgumentException(
                "toKeep must be -1 (keep all) or non-negative, got " + toKeep);
        }
        MultivariateGaussian mg = new MultivariateGaussian();
        mg.estimate(dataSet);
        Matrix covarianceMatrix = mg.getCovarianceMatrix();
        mean = mg.getMean();

        // never ask for more components than the data has dimensions,
        // otherwise the scan below would read past the last eigenvalue
        int dimension = mean.size();
        int limit = (toKeep == -1) ? dimension : Math.min(toKeep, dimension);

        SymmetricEigenvalueDecomposition sed =
            new SymmetricEigenvalueDecomposition(covarianceMatrix);
        Matrix eigenVectors = sed.getU();
        eigenValues = sed.getD();

        // count the leading eigenvalues that exceed the threshold;
        // NOTE(review): this stops at the first small eigenvalue, so it
        // assumes the decomposition orders them largest first - confirm
        int kept = 0;
        while (kept < limit && eigenValues.get(kept, kept) > threshold) {
            kept++;
        }

        // each retained eigenvector (a column of U) becomes a row of the
        // projection matrix
        projection = new RectangularMatrix(kept, eigenVectors.m());
        for (int i = 0; i < kept; i++) {
            projection.setRow(i, eigenVectors.getColumn(i));
        }
    }

    /**
     * Make a new PCA filter using the default eigenvalue threshold
     * @param set the data set to estimate components from
     * @param numberOfComponents the maximum number of components to keep,
     * or -1 to consider all of them
     */
    public PrincipalComponentAnalysis(DataSet set, int numberOfComponents) {
        this(set, numberOfComponents, THRESHOLD);
    }

    /**
     * Make a new PCA filter that considers every component and uses the
     * default eigenvalue threshold
     * @param set the data set to estimate components from
     */
    public PrincipalComponentAnalysis(DataSet set) {
        this(set, -1);
    }

    /**
     * Project every instance of the data set onto the retained components.
     * Each instance has the estimated mean subtracted and is then
     * multiplied by the projection matrix; the instances are modified in
     * place and the data set's description is reset to null.
     * @param dataSet the data set to transform
     * @see shared.filt.DataSetFilter#filter(shared.DataSet)
     */
    public void filter(DataSet dataSet) {
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            instance.setData(instance.getData().minus(mean));
            instance.setData(projection.times(instance.getData()));
        }
        dataSet.setDescription(null);
    }

    /**
     * Map every instance of a projected data set back into the original
     * space. Each instance is multiplied by the transpose of the
     * projection matrix and then has the estimated mean added back; the
     * instances are modified in place and the data set's description is
     * reset to null.
     * @param dataSet the data set to transform back
     * @see shared.filt.ReversibleFilter#reverse(shared.DataSet)
     */
    public void reverse(DataSet dataSet) {
        // the transpose does not depend on the instance, so compute it once
        Matrix backProjection = projection.transpose();
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            instance.setData(backProjection.times(instance.getData()));
            instance.setData(instance.getData().plus(mean));
        }
        dataSet.setDescription(null);
    }

    /**
     * Get the projection matrix used
     * @return the projection matrix, one retained eigenvector per row
     */
    public Matrix getProjection() {
        return projection;
    }

    /**
     * Get the mean
     * @return the mean vector estimated from the training data
     */
    public Vector getMean() {
        return mean;
    }

    /**
     * Get the eigenvalues
     * @return the full eigenvalue matrix from the decomposition, including
     * entries for components that were not retained
     */
    public Matrix getEigenValues() {
        return eigenValues;
    }
}