/*
* Copyright 2010, 2011 Institut Pasteur.
*
* This file is part of NHerve Main Toolbox, which is an ICY plugin.
*
* NHerve Main Toolbox is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* NHerve Main Toolbox is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with NHerve Main Toolbox. If not, see <http://www.gnu.org/licenses/>.
*/
package plugins.nherve.toolbox.image.feature;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import plugins.nherve.matrix.Matrix;
import plugins.nherve.toolbox.image.feature.signature.DenseVectorSignature;
import plugins.nherve.toolbox.image.feature.signature.SignatureException;
import plugins.nherve.toolbox.image.feature.signature.DefaultVectorSignature;
/**
* The Class LDA.
*
* @author Nicolas HERVE - nicolas.herve@pasteur.fr
*/
public class LDA extends DimensionReductionAlgorithm {
/** The classes. */
private List<Integer> classes;
/** The inv pooled. */
private Matrix invPooled;
/** The const stuff. */
private Matrix constStuff;
/** The classes mean. */
private HashMap<Integer, Matrix> classesMean;
/** The nb groups. */
int nbGroups;
/**
* Instantiates a new lDA.
*
* @param signatures
* the signatures
* @param classes
* the classes
*/
public LDA(List<DefaultVectorSignature> signatures, List<Integer> classes) {
super(signatures);
this.classes = classes;
this.invPooled = null;
this.classesMean = null;
this.constStuff = null;
}
/* (non-Javadoc)
* @see plugins.nherve.toolbox.image.feature.DimensionReductionAlgorithm#compute()
*/
@Override
public void compute() throws SignatureException {
check();
Matrix m = getMatrix(signatures);
Matrix globalMean = getMean(m);
if (isLogEnabled()) {
log("Global mean : ");
globalMean.print(20, 15);
}
HashMap<Integer, Integer> classesCardinality = new HashMap<Integer, Integer>();
for (int idx = 0; idx < m.getRowDimension(); idx++) {
int c = classes.get(idx);
if (classesCardinality.containsKey(c)) {
classesCardinality.put(c, classesCardinality.get(c) + 1);
} else {
classesCardinality.put(c, 1);
}
}
nbGroups = classesCardinality.size();
HashMap<Integer, Matrix> classesMatrix = new HashMap<Integer, Matrix>();
for (int g : classesCardinality.keySet()) {
log("Class " + g + " has " + classesCardinality.get(g) + " members");
classesMatrix.put(g, new Matrix(classesCardinality.get(g), dim));
}
for (int idx = 0; idx < m.getRowDimension(); idx++) {
int c = classes.get(idx);
Matrix mc = classesMatrix.get(c);
int cc = classesCardinality.get(c) - 1;
for (int d = 0; d < dim; d++) {
mc.set(cc, d, m.get(idx, d));
}
classesCardinality.put(c, cc);
}
classesMean = new HashMap<Integer, Matrix>();
for (int g : classesMatrix.keySet()) {
Matrix lm = getMean(classesMatrix.get(g));
if (isLogEnabled()) {
log("Class " + g + " mean : ");
lm.print(20, 15);
}
classesMean.put(g, lm);
}
for (int g : classesMatrix.keySet()) {
Matrix cm = classesMatrix.get(g);
Matrix ones = new Matrix(cm.getRowDimension(), 1, -1);
cm.plusEquals(ones.times(globalMean));
}
HashMap<Integer, Matrix> varcov = new HashMap<Integer, Matrix>();
for (int g : classesMatrix.keySet()) {
Matrix mx = getVarCovMatrix(classesMatrix.get(g));
if (isLogEnabled()) {
log("C"+g+": ");
mx.print(20, 15);
}
varcov.put(g, mx);
}
Matrix pooled = new Matrix(dim, dim, 0);
for (int g : classesMatrix.keySet()) {
Matrix vc = varcov.get(g);
pooled.plusEquals(vc.times(classesMatrix.get(g).getRowDimension()));
}
pooled.timesEquals(1.0 / (double) m.getRowDimension());
if (isLogEnabled()) {
log("Pooled C: ");
pooled.print(20, 15);
}
invPooled = pooled.inverse();
Matrix p = new Matrix(1, nbGroups);
for (int g : classesMatrix.keySet()) {
p.set(0, g, Math.log((double)classesMatrix.get(g).getRowDimension() / (double)m.getRowDimension()));
}
constStuff = new Matrix(1, nbGroups);
for (int g : classesMatrix.keySet()) {
Matrix mn = classesMean.get(g);
double c = p.get(0, g) - 0.5 * ((mn.times(invPooled)).times(mn.transpose())).get(0, 0);
constStuff.set(0, g, c);
}
log("LDA done");
}
/**
* Project.
*
* @param toProject
* the to project
* @return the vector signature
* @throws SignatureException
* the signature exception
*/
public DefaultVectorSignature project(DefaultVectorSignature toProject) throws SignatureException {
DenseVectorSignature vs = new DenseVectorSignature(Math.max(nbGroups, 3));
ArrayList<DefaultVectorSignature> a = new ArrayList<DefaultVectorSignature>();
a.add(toProject);
Matrix x = getMatrix(a);
for (int g = 0; g < nbGroups; g++) {
Matrix mn = classesMean.get(g);
double c = ((mn.times(invPooled)).times(x.transpose())).get(0, 0);
vs.set(g, c + constStuff.get(0, g));
}
return vs;
}
/* (non-Javadoc)
* @see plugins.nherve.toolbox.image.feature.DimensionReductionAlgorithm#project(java.util.List)
*/
@Override
public List<DefaultVectorSignature> project(List<DefaultVectorSignature> toProject) throws SignatureException {
ArrayList<DefaultVectorSignature> proj = new ArrayList<DefaultVectorSignature>();
for (DefaultVectorSignature vs : toProject) {
proj.add(project(vs));
}
return proj;
}
}