package shared.filt;
import dist.MultivariateGaussian;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import util.linalg.CholeskyFactorization;
import util.linalg.LowerTriangularMatrix;
import util.linalg.Matrix;
import util.linalg.RectangularMatrix;
import util.linalg.SymmetricEigenvalueDecomposition;
import util.linalg.UpperTriangularMatrix;
import util.linalg.Vector;
/**
* A filter that performs Fisher linear discriminant
* analysis (LDA) on a data set.
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class LinearDiscriminantAnalysis implements ReversibleFilter {
/**
* The projection matrix
*/
private Matrix projection;
/**
* The mean
*/
private Vector mean;
/**
* Make a new PCA filter
* @param toKeep the number of components to keep
* @param dataSet the set form which to estimate components
*/
public LinearDiscriminantAnalysis(DataSet dataSet) {
// calculate the mean
MultivariateGaussian mg = new MultivariateGaussian();
mg.estimate(dataSet);
mean = mg.getMean();
if (dataSet.getDescription() == null) {
dataSet.setDescription(new DataSetDescription(dataSet));
}
// calculate the class counts and weight sums
int classCount = dataSet.getDescription()
.getLabelDescription().getDiscreteRange();
int toKeep = classCount - 1;
int[] classCounts = new int[classCount];
double[] weightSums = new double[classCount];
double weightSum = 0;
for (int i = 0; i < dataSet.size(); i++) {
int classification = dataSet.get(i).getLabel().getDiscrete();
classCounts[classification]++;
weightSums[classification] += dataSet.get(i).getWeight();
weightSum += dataSet.get(i).getWeight();
}
// normalize the weight sums
for (int i = 0; i < weightSums.length; i++) {
weightSums[i] /= weightSum;
}
// seperate out the data
Instance[][] instances = new Instance[classCount][];
for (int i = 0; i < instances.length; i++) {
instances[i] = new Instance[classCounts[i]];
classCounts[i] = 0;
}
for (int i = 0; i < dataSet.size(); i++) {
int classification = dataSet.get(i).getLabel().getDiscrete();
instances[classification][classCounts[classification]] = dataSet.get(i);
classCounts[classification]++;
}
// the between class covariance matrix
Matrix sb = new RectangularMatrix(mean.size(), mean.size());
// the within class covariance matrix
Matrix sw = new RectangularMatrix(mean.size(), mean.size());
// calculate the two matrices
for (int i = 0; i < classCount; i++) {
mg = new MultivariateGaussian();
mg.estimate(new DataSet(instances[i]));
sw.plusEquals(mg.getCovarianceMatrix().times(weightSums[i]));
Vector classMean = mg.getMean();
Vector classMeanMinusMean = classMean.minus(mean);
sb.plusEquals(classMeanMinusMean.outerProduct(
classMeanMinusMean).times(weightSums[i]));
}
// solve the symmetric-definite generalized eigenvalue problem
CholeskyFactorization cf = new CholeskyFactorization(sw);
LowerTriangularMatrix g = cf.getL();
LowerTriangularMatrix gInverse = g.inverse();
UpperTriangularMatrix gInverseTranspose = (UpperTriangularMatrix) gInverse.transpose();
Matrix c = gInverse.times(sb).times(gInverseTranspose);
SymmetricEigenvalueDecomposition sed = new SymmetricEigenvalueDecomposition(c);
Matrix eigenVectors = gInverseTranspose.times(sed.getU());
// keep the top vectors
projection = new RectangularMatrix(toKeep, eigenVectors.m());
for (int i = 0; i < toKeep; i++) {
Vector v = eigenVectors.getColumn(i);
projection.setRow(i, v.times(1.0/v.norm()));
}
}
/**
* @see shared.filt.DataSetFilter#filter(shared.DataSet)
*/
public void filter(DataSet dataSet) {
for (int i = 0; i < dataSet.size(); i++) {
Instance instance = dataSet.get(i);
instance.setData(instance.getData().minus(mean));
instance.setData(projection.times(instance.getData()));
}
dataSet.setDescription(null);
}
/**
* @see shared.filt.ReversibleFilter#reverse(shared.DataSet)
*/
public void reverse(DataSet dataSet) {
for (int i = 0; i < dataSet.size(); i++) {
Instance instance = dataSet.get(i);
instance.setData(projection.transpose().times(instance.getData()));
instance.setData(instance.getData().plus(mean));
}
dataSet.setDescription(null);
}
/**
* Get the projection matrix used
* @return the projection matrix
*/
public Matrix getProjection() {
return projection;
}
/**
* Get the mean
* @return the mean vector
*/
public Vector getMean() {
return mean;
}
}