/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.bayes;
import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.MathFunctions;
import com.rapidminer.tools.math.matrix.CovarianceMatrix;
/**
* <p>This operator performs a linear discriminant analysis (LDA). This method tries to find the
* linear combination of features which best separates two or more classes of examples. The resulting
* combination is then used as a linear classifier. LDA is closely related to ANOVA (analysis of
* variance) and regression analysis, which also attempt to express one dependent variable as a
* linear combination of other features or measurements. In the other two methods however,
* the dependent variable is a numerical quantity, while for LDA it is a categorical variable
* (i.e. the class label).</p>
*
* <p>LDA is also closely related to principal component analysis (PCA) and factor analysis in
* that both look for linear combinations of variables which best explain the data. LDA explicitly
* attempts to model the difference between the classes of data. PCA on the other hand does not
* take into account any difference in class.</p>
*
* @author Sebastian Land
*/
public class LinearDiscriminantAnalysis extends AbstractLearner {

    /**
     * Creates the operator from the given description.
     *
     * @param description the operator description handed to the superclass
     */
    public LinearDiscriminantAnalysis(OperatorDescription description) {
        super(description);
    }

    /**
     * Learns a discriminant model from the given example set. The method counts the
     * numerical attributes, reads all label values from the label's nominal mapping and
     * then derives the class mean vectors, the inverse covariance matrices and the
     * a-priori class probabilities which are passed on to
     * {@link #getModel(ExampleSet, String[], Matrix[], Matrix[], double[])}.
     *
     * @param exampleSet the training data; its label attribute must provide a nominal mapping
     * @return the model created by {@code getModel}
     * @throws OperatorException if one of the computation steps fails, e.g. if the number of
     *         label subsets does not match the number of label values
     */
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        int numberOfNumericalAttributes = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (attribute.isNumerical()) {
                numberOfNumericalAttributes++;
            }
        }

        NominalMapping labelMapping = exampleSet.getAttributes().getLabel().getMapping();
        String[] labelValues = new String[labelMapping.size()];
        for (int i = 0; i < labelValues.length; i++) {
            labelValues[i] = labelMapping.mapIndex(i);
        }

        Matrix[] meanVectors = getMeanVectors(exampleSet, numberOfNumericalAttributes, labelValues);
        Matrix[] inverseCovariance = getInverseCovarianceMatrices(exampleSet, labelValues);
        return getModel(exampleSet, labelValues, meanVectors, inverseCovariance, getAprioriProbabilities(exampleSet, labelValues));
    }

    /**
     * Builds the resulting {@link DiscriminantModel}. Subclasses may override this method
     * to create a differently parameterized model.
     *
     * @param exampleSet the training data
     * @param labels the label values, in the order used for all per-class arrays
     * @param meanVectors one mean vector per label value
     * @param inverseCovariances one inverse covariance matrix per label value
     * @param aprioriProbabilities one a-priori probability per label value
     * @return the new model
     * @throws UndefinedParameterError declared for overriding implementations; this
     *         implementation reads no parameters itself
     */
    protected DiscriminantModel getModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities) throws UndefinedParameterError {
        // NOTE(review): the trailing 0d is an extra model argument whose meaning is defined by
        // DiscriminantModel (not visible here) -- confirm against its constructor.
        return new DiscriminantModel(exampleSet, labels, meanVectors, inverseCovariances, aprioriProbabilities, 0d);
    }

    /**
     * Computes for every label value the fraction of examples carrying this label.
     *
     * @param exampleSet the training data
     * @param labels the label values
     * @return the relative class frequencies, in the order of {@code labels}
     */
    private double[] getAprioriProbabilities(ExampleSet exampleSet, String[] labels) {
        double[] aprioriProbabilities = new double[labels.length];
        // kept as double so that the division below is a floating point division
        double totalSize = exampleSet.size();
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, labelAttribute);

        for (int labelIndex = 0; labelIndex < labels.length; labelIndex++) {
            selectSubsetForLabel(labelSet, labelAttribute, labels[labelIndex], labels.length);
            aprioriProbabilities[labelIndex] = labelSet.size() / totalSize;
        }
        return aprioriProbabilities;
    }

    /**
     * Computes the mean of every numerical attribute separately for each label value.
     *
     * @param exampleSet the training data
     * @param numberOfAttributes the length of each mean vector; {@code learn} passes the
     *        number of numerical attributes
     * @param labels the label values
     * @return one matrix per label value, each created from the class means with
     *         {@code new Matrix(meanValues, 1)}
     * @throws UserError (error 118) if splitting by the label yields a number of subsets
     *         that differs from the number of label values
     */
    protected Matrix[] getMeanVectors(ExampleSet exampleSet, int numberOfAttributes, String[] labels) throws UserError {
        Matrix[] classMeanVectors = new Matrix[labels.length];
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, labelAttribute);
        if (labelSet.getNumberOfSubsets() != labels.length) {
            throw new UserError(this, 118, labelAttribute, labelSet.getNumberOfSubsets(), 2);
        }

        for (int labelIndex = 0; labelIndex < labels.length; labelIndex++) {
            selectSubsetForLabel(labelSet, labelAttribute, labels[labelIndex], labels.length);

            // the statistics must be refreshed since they depend on the selected subset
            labelSet.recalculateAllAttributeStatistics();
            double[] meanValues = new double[numberOfAttributes];
            int position = 0;
            for (Attribute attribute : labelSet.getAttributes()) {
                if (attribute.isNumerical()) {
                    meanValues[position] = labelSet.getStatistics(attribute, Statistics.AVERAGE);
                    // advance only for numerical attributes: the array is sized by their count,
                    // so skipped attributes must neither leave gaps nor overrun the array
                    position++;
                }
            }
            classMeanVectors[labelIndex] = new Matrix(meanValues, 1);
        }
        return classMeanVectors;
    }

    /**
     * Delivers the inverse covariance matrix for each label value. This implementation
     * inverts the covariance matrix computed on the complete example set once and uses
     * the very same matrix instance for all classes.
     *
     * @param exampleSet the training data
     * @param labels the label values
     * @return an array of length {@code labels.length} whose entries all reference the same matrix
     * @throws UndefinedParameterError declared for overriding implementations; this
     *         implementation reads no parameters itself
     */
    protected Matrix[] getInverseCovarianceMatrices(ExampleSet exampleSet, String[] labels) throws UndefinedParameterError {
        Matrix[] classInverseCovariances = new Matrix[labels.length];
        Matrix inverse = MathFunctions.invertMatrix(CovarianceMatrix.getCovarianceMatrix(exampleSet));
        for (int i = 0; i < labels.length; i++) {
            classInverseCovariances[i] = inverse;
        }
        return classInverseCovariances;
    }

    /**
     * Returns the class of the models created by this learner.
     */
    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return DiscriminantModel.class;
    }

    /**
     * Reports the supported capabilities: numerical attributes as well as binominal
     * and polynominal labels. All other capabilities are rejected.
     *
     * @param capability the capability in question
     * @return {@code true} only for the three capabilities named above
     */
    public boolean supportsCapability(OperatorCapability capability) {
        return capability.equals(OperatorCapability.NUMERICAL_ATTRIBUTES)
                || capability.equals(OperatorCapability.BINOMINAL_LABEL)
                || capability.equals(OperatorCapability.POLYNOMINAL_LABEL);
    }

    /**
     * Selects the subset of the given splitted example set whose first example carries the
     * given label value. The subsets are probed one after another since the subset index is
     * not assumed to be equal to the label index. If no probed subset matches, the subset
     * probed last remains selected.
     *
     * @param labelSet the example set splitted by the label attribute
     * @param labelAttribute the label attribute
     * @param label the label value searched for
     * @param numberOfCandidates the number of subsets to probe, starting with subset 0
     */
    private static void selectSubsetForLabel(SplittedExampleSet labelSet, Attribute labelAttribute, String label, int numberOfCandidates) {
        for (int subset = 0; subset < numberOfCandidates; subset++) {
            labelSet.selectSingleSubset(subset);
            if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) {
                return;
            }
        }
    }
}