/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* PLSFilter.java
* Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
*
*/
package weka.filters.supervised.attribute;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.EigenvalueDecomposition;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;
import weka.filters.SupervisedFilter;
import weka.filters.unsupervised.attribute.Center;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;
import java.util.Enumeration;
import java.util.Vector;
/**
<!-- globalinfo-start -->
* Runs Partial Least Square Regression over the given instances and computes the resulting beta matrix for prediction.<br/>
* By default it replaces missing values and centers the data.<br/>
* <br/>
* For more information see:<br/>
* <br/>
* Tormod Naes, Tomas Isaksson, Tom Fearn, Tony Davies (2002). A User Friendly Guide to Multivariate Calibration and Classification. NIR Publications.<br/>
* <br/>
* StatSoft, Inc.. Partial Least Squares (PLS).<br/>
* <br/>
* Bent Jorgensen, Yuri Goegebeur. Module 7: Partial least squares regression I.<br/>
* <br/>
* S. de Jong (1993). SIMPLS: an alternative approach to partial least squares regression. Chemometrics and Intelligent Laboratory Systems. 18:251-263.
* <p/>
<!-- globalinfo-end -->
*
<!-- technical-bibtex-start -->
* BibTeX:
* <pre>
* @book{Naes2002,
* author = {Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies},
* publisher = {NIR Publications},
* title = {A User Friendly Guide to Multivariate Calibration and Classification},
* year = {2002},
* ISBN = {0-9528666-2-5}
* }
*
* @misc{missing_id,
* author = {StatSoft, Inc.},
* booktitle = {Electronic Textbook StatSoft},
* title = {Partial Least Squares (PLS)},
* HTTP = {http://www.statsoft.com/textbook/stpls.html}
* }
*
* @misc{missing_id,
* author = {Bent Jorgensen and Yuri Goegebeur},
* booktitle = {ST02: Multivariate Data Analysis and Chemometrics},
* title = {Module 7: Partial least squares regression I},
* HTTP = {http://statmaster.sdu.dk/courses/ST02/module07/}
* }
*
* @article{Jong1993,
* author = {S. de Jong},
* journal = {Chemometrics and Intelligent Laboratory Systems},
* pages = {251-263},
* title = {SIMPLS: an alternative approach to partial least squares regression},
* volume = {18},
* year = {1993}
* }
* </pre>
* <p/>
<!-- technical-bibtex-end -->
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -D
* Turns on output of debugging information.</pre>
*
* <pre> -C <num>
* The number of components to compute.
* (default: 20)</pre>
*
* <pre> -U
* Updates the class attribute as well.
* (default: off)</pre>
*
* <pre> -M
* Turns replacing of missing values on.
* (default: off)</pre>
*
* <pre> -A <SIMPLS|PLS1>
* The algorithm to use.
* (default: PLS1)</pre>
*
* <pre> -P <none|center|standardize>
* The type of preprocessing that is applied to the data.
* (default: center)</pre>
*
<!-- options-end -->
*
* @author FracPete (fracpete at waikato dot ac dot nz)
* @version $Revision: 5987 $
*/
public class PLSFilter
extends SimpleBatchFilter
implements SupervisedFilter, TechnicalInformationHandler {
/** for serialization */
static final long serialVersionUID = -3335106965521265631L;
/** the type of algorithm: SIMPLS */
public static final int ALGORITHM_SIMPLS = 1;
/** the type of algorithm: PLS1 */
public static final int ALGORITHM_PLS1 = 2;
/** the types of algorithm */
public static final Tag[] TAGS_ALGORITHM = {
new Tag(ALGORITHM_SIMPLS, "SIMPLS"),
new Tag(ALGORITHM_PLS1, "PLS1")
};
/** the type of preprocessing: None */
public static final int PREPROCESSING_NONE = 0;
/** the type of preprocessing: Center */
public static final int PREPROCESSING_CENTER = 1;
/** the type of preprocessing: Standardize */
public static final int PREPROCESSING_STANDARDIZE = 2;
/** the types of preprocessing */
public static final Tag[] TAGS_PREPROCESSING = {
new Tag(PREPROCESSING_NONE, "none"),
new Tag(PREPROCESSING_CENTER, "center"),
new Tag(PREPROCESSING_STANDARDIZE, "standardize")
};
/** the maximum number of components to generate */
protected int m_NumComponents = 20;
/** the type of algorithm */
protected int m_Algorithm = ALGORITHM_PLS1;
/** the regression vector "r-hat" for PLS1 */
protected Matrix m_PLS1_RegVector = null;
/** the P matrix for PLS1 */
protected Matrix m_PLS1_P = null;
/** the W matrix for PLS1 */
protected Matrix m_PLS1_W = null;
/** the b-hat vector for PLS1 */
protected Matrix m_PLS1_b_hat = null;
/** the W matrix for SIMPLS */
protected Matrix m_SIMPLS_W = null;
/** the B matrix for SIMPLS (used for prediction) */
protected Matrix m_SIMPLS_B = null;
/** whether to include the prediction, i.e., modifying the class attribute */
protected boolean m_PerformPrediction = false;
/** for replacing missing values */
protected Filter m_Missing = null;
/** whether to replace missing values */
protected boolean m_ReplaceMissing = true;
/** for centering the data */
protected Filter m_Filter = null;
/** the type of preprocessing */
protected int m_Preprocessing = PREPROCESSING_CENTER;
/** the mean of the class */
protected double m_ClassMean = 0;
/** the standard deviation of the class */
protected double m_ClassStdDev = 0;
/**
* default constructor
*/
public PLSFilter() {
super();
// setup pre-processing
m_Missing = new ReplaceMissingValues();
m_Filter = new Center();
}
/**
* Returns a string describing this classifier.
*
* @return a description of the classifier suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return
"Runs Partial Least Square Regression over the given instances "
+ "and computes the resulting beta matrix for prediction.\n"
+ "By default it replaces missing values and centers the data.\n\n"
+ "For more information see:\n\n"
+ getTechnicalInformation().toString();
}
/**
* Returns an instance of a TechnicalInformation object, containing
* detailed information about the technical background of this class,
* e.g., paper reference or book this class is based on.
*
* @return the technical information about this class
*/
public TechnicalInformation getTechnicalInformation() {
TechnicalInformation result;
TechnicalInformation additional;
result = new TechnicalInformation(Type.BOOK);
result.setValue(Field.AUTHOR, "Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies");
result.setValue(Field.YEAR, "2002");
result.setValue(Field.TITLE, "A User Friendly Guide to Multivariate Calibration and Classification");
result.setValue(Field.PUBLISHER, "NIR Publications");
result.setValue(Field.ISBN, "0-9528666-2-5");
additional = result.add(Type.MISC);
additional.setValue(Field.AUTHOR, "StatSoft, Inc.");
additional.setValue(Field.TITLE, "Partial Least Squares (PLS)");
additional.setValue(Field.BOOKTITLE, "Electronic Textbook StatSoft");
additional.setValue(Field.HTTP, "http://www.statsoft.com/textbook/stpls.html");
additional = result.add(Type.MISC);
additional.setValue(Field.AUTHOR, "Bent Jorgensen and Yuri Goegebeur");
additional.setValue(Field.TITLE, "Module 7: Partial least squares regression I");
additional.setValue(Field.BOOKTITLE, "ST02: Multivariate Data Analysis and Chemometrics");
additional.setValue(Field.HTTP, "http://statmaster.sdu.dk/courses/ST02/module07/");
additional = result.add(Type.ARTICLE);
additional.setValue(Field.AUTHOR, "S. de Jong");
additional.setValue(Field.YEAR, "1993");
additional.setValue(Field.TITLE, "SIMPLS: an alternative approach to partial least squares regression");
additional.setValue(Field.JOURNAL, "Chemometrics and Intelligent Laboratory Systems");
additional.setValue(Field.VOLUME, "18");
additional.setValue(Field.PAGES, "251-263");
return result;
}
/**
* Gets an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector result;
Enumeration enm;
String param;
SelectedTag tag;
int i;
result = new Vector();
enm = super.listOptions();
while (enm.hasMoreElements())
result.addElement(enm.nextElement());
result.addElement(new Option(
"\tThe number of components to compute.\n"
+ "\t(default: 20)",
"C", 1, "-C <num>"));
result.addElement(new Option(
"\tUpdates the class attribute as well.\n"
+ "\t(default: off)",
"U", 0, "-U"));
result.addElement(new Option(
"\tTurns replacing of missing values on.\n"
+ "\t(default: off)",
"M", 0, "-M"));
param = "";
for (i = 0; i < TAGS_ALGORITHM.length; i++) {
if (i > 0)
param += "|";
tag = new SelectedTag(TAGS_ALGORITHM[i].getID(), TAGS_ALGORITHM);
param += tag.getSelectedTag().getReadable();
}
result.addElement(new Option(
"\tThe algorithm to use.\n"
+ "\t(default: PLS1)",
"A", 1, "-A <" + param + ">"));
param = "";
for (i = 0; i < TAGS_PREPROCESSING.length; i++) {
if (i > 0)
param += "|";
tag = new SelectedTag(TAGS_PREPROCESSING[i].getID(), TAGS_PREPROCESSING);
param += tag.getSelectedTag().getReadable();
}
result.addElement(new Option(
"\tThe type of preprocessing that is applied to the data.\n"
+ "\t(default: center)",
"P", 1, "-P <" + param + ">"));
return result.elements();
}
/**
* returns the options of the current setup
*
* @return the current options
*/
public String[] getOptions() {
int i;
Vector result;
String[] options;
result = new Vector();
options = super.getOptions();
for (i = 0; i < options.length; i++)
result.add(options[i]);
result.add("-C");
result.add("" + getNumComponents());
if (getPerformPrediction())
result.add("-U");
if (getReplaceMissing())
result.add("-M");
result.add("-A");
result.add("" + getAlgorithm().getSelectedTag().getReadable());
result.add("-P");
result.add("" + getPreprocessing().getSelectedTag().getReadable());
return (String[]) result.toArray(new String[result.size()]);
}
/**
* Parses the options for this object. <p/>
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -D
* Turns on output of debugging information.</pre>
*
* <pre> -C <num>
* The number of components to compute.
* (default: 20)</pre>
*
* <pre> -U
* Updates the class attribute as well.
* (default: off)</pre>
*
* <pre> -M
* Turns replacing of missing values on.
* (default: off)</pre>
*
* <pre> -A <SIMPLS|PLS1>
* The algorithm to use.
* (default: PLS1)</pre>
*
* <pre> -P <none|center|standardize>
* The type of preprocessing that is applied to the data.
* (default: center)</pre>
*
<!-- options-end -->
*
* @param options the options to use
* @throws Exception if the option setting fails
*/
public void setOptions(String[] options) throws Exception {
String tmpStr;
super.setOptions(options);
tmpStr = Utils.getOption("C", options);
if (tmpStr.length() != 0)
setNumComponents(Integer.parseInt(tmpStr));
else
setNumComponents(20);
setPerformPrediction(Utils.getFlag("U", options));
setReplaceMissing(Utils.getFlag("M", options));
tmpStr = Utils.getOption("A", options);
if (tmpStr.length() != 0)
setAlgorithm(new SelectedTag(tmpStr, TAGS_ALGORITHM));
else
setAlgorithm(new SelectedTag(ALGORITHM_PLS1, TAGS_ALGORITHM));
tmpStr = Utils.getOption("P", options);
if (tmpStr.length() != 0)
setPreprocessing(new SelectedTag(tmpStr, TAGS_PREPROCESSING));
else
setPreprocessing(new SelectedTag(PREPROCESSING_CENTER, TAGS_PREPROCESSING));
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String numComponentsTipText() {
return "The number of components to compute.";
}
/**
* sets the maximum number of attributes to use.
*
* @param value the maximum number of attributes
*/
public void setNumComponents(int value) {
m_NumComponents = value;
}
/**
* returns the maximum number of attributes to use.
*
* @return the current maximum number of attributes
*/
public int getNumComponents() {
return m_NumComponents;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String performPredictionTipText() {
return "Whether to update the class attribute with the predicted value.";
}
/**
* Sets whether to update the class attribute with the predicted value.
*
* @param value if true the class value will be replaced by the
* predicted value.
*/
public void setPerformPrediction(boolean value) {
m_PerformPrediction = value;
}
/**
* Gets whether the class attribute is updated with the predicted value.
*
* @return true if the class attribute is updated
*/
public boolean getPerformPrediction() {
return m_PerformPrediction;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String algorithmTipText() {
return "Sets the type of algorithm to use.";
}
/**
* Sets the type of algorithm to use
*
* @param value the algorithm type
*/
public void setAlgorithm(SelectedTag value) {
if (value.getTags() == TAGS_ALGORITHM) {
m_Algorithm = value.getSelectedTag().getID();
}
}
/**
* Gets the type of algorithm to use
*
* @return the current algorithm type.
*/
public SelectedTag getAlgorithm() {
return new SelectedTag(m_Algorithm, TAGS_ALGORITHM);
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String replaceMissingTipText() {
return "Whether to replace missing values.";
}
/**
* Sets whether to replace missing values.
*
* @param value if true missing values are replaced with the
* ReplaceMissingValues filter.
*/
public void setReplaceMissing(boolean value) {
m_ReplaceMissing = value;
}
/**
* Gets whether missing values are replace.
*
* @return true if missing values are replaced with the
* ReplaceMissingValues filter
*/
public boolean getReplaceMissing() {
return m_ReplaceMissing;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String preprocessingTipText() {
return "Sets the type of preprocessing to use.";
}
/**
* Sets the type of preprocessing to use
*
* @param value the preprocessing type
*/
public void setPreprocessing(SelectedTag value) {
if (value.getTags() == TAGS_PREPROCESSING) {
m_Preprocessing = value.getSelectedTag().getID();
}
}
/**
* Gets the type of preprocessing to use
*
* @return the current preprocessing type.
*/
public SelectedTag getPreprocessing() {
return new SelectedTag(m_Preprocessing, TAGS_PREPROCESSING);
}
/**
* Determines the output format based on the input format and returns
* this. In case the output format cannot be returned immediately, i.e.,
* immediateOutputFormat() returns false, then this method will be called
* from batchFinished().
*
* @param inputFormat the input format to base the output format on
* @return the output format
* @throws Exception in case the determination goes wrong
* @see #hasImmediateOutputFormat()
* @see #batchFinished()
*/
protected Instances determineOutputFormat(Instances inputFormat)
throws Exception {
// generate header
FastVector atts = new FastVector();
String prefix = getAlgorithm().getSelectedTag().getReadable();
for (int i = 0; i < getNumComponents(); i++)
atts.addElement(new Attribute(prefix + "_" + (i+1)));
atts.addElement(new Attribute("Class"));
Instances result = new Instances(prefix, atts, 0);
result.setClassIndex(result.numAttributes() - 1);
return result;
}
/**
* returns the data minus the class column as matrix
*
* @param instances the data to work on
* @return the data without class attribute
*/
protected Matrix getX(Instances instances) {
double[][] x;
double[] values;
Matrix result;
int i;
int n;
int j;
int clsIndex;
clsIndex = instances.classIndex();
x = new double[instances.numInstances()][];
for (i = 0; i < instances.numInstances(); i++) {
values = instances.instance(i).toDoubleArray();
x[i] = new double[values.length - 1];
j = 0;
for (n = 0; n < values.length; n++) {
if (n != clsIndex) {
x[i][j] = values[n];
j++;
}
}
}
result = new Matrix(x);
return result;
}
/**
* returns the data minus the class column as matrix
*
* @param instance the instance to work on
* @return the data without the class attribute
*/
protected Matrix getX(Instance instance) {
double[][] x;
double[] values;
Matrix result;
x = new double[1][];
values = instance.toDoubleArray();
x[0] = new double[values.length - 1];
System.arraycopy(values, 0, x[0], 0, values.length - 1);
result = new Matrix(x);
return result;
}
/**
* returns the data class column as matrix
*
* @param instances the data to work on
* @return the class attribute
*/
protected Matrix getY(Instances instances) {
double[][] y;
Matrix result;
int i;
y = new double[instances.numInstances()][1];
for (i = 0; i < instances.numInstances(); i++)
y[i][0] = instances.instance(i).classValue();
result = new Matrix(y);
return result;
}
/**
* returns the data class column as matrix
*
* @param instance the instance to work on
* @return the class attribute
*/
protected Matrix getY(Instance instance) {
double[][] y;
Matrix result;
y = new double[1][1];
y[0][0] = instance.classValue();
result = new Matrix(y);
return result;
}
/**
* returns the X and Y matrix again as Instances object, based on the given
* header (must have a class attribute set).
*
* @param header the format of the instance object
* @param x the X matrix (data)
* @param y the Y matrix (class)
* @return the assembled data
*/
protected Instances toInstances(Instances header, Matrix x, Matrix y) {
double[] values;
int i;
int n;
Instances result;
int rows;
int cols;
int offset;
int clsIdx;
result = new Instances(header, 0);
rows = x.getRowDimension();
cols = x.getColumnDimension();
clsIdx = header.classIndex();
for (i = 0; i < rows; i++) {
values = new double[cols + 1];
offset = 0;
for (n = 0; n < values.length; n++) {
if (n == clsIdx) {
offset--;
values[n] = y.get(i, 0);
}
else {
values[n] = x.get(i, n + offset);
}
}
result.add(new DenseInstance(1.0, values));
}
return result;
}
/**
* returns the given column as a vector (actually a n x 1 matrix)
*
* @param m the matrix to work on
* @param columnIndex the column to return
* @return the column as n x 1 matrix
*/
protected Matrix columnAsVector(Matrix m, int columnIndex) {
Matrix result;
int i;
result = new Matrix(m.getRowDimension(), 1);
for (i = 0; i < m.getRowDimension(); i++)
result.set(i, 0, m.get(i, columnIndex));
return result;
}
/**
* stores the data from the (column) vector in the matrix at the specified
* index
*
* @param v the vector to store in the matrix
* @param m the receiving matrix
* @param columnIndex the column to store the values in
*/
protected void setVector(Matrix v, Matrix m, int columnIndex) {
m.setMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex, v);
}
/**
* returns the (column) vector of the matrix at the specified index
*
* @param m the matrix to work on
* @param columnIndex the column to get the values from
* @return the column vector
*/
protected Matrix getVector(Matrix m, int columnIndex) {
return m.getMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex);
}
/**
* determines the dominant eigenvector for the given matrix and returns it
*
* @param m the matrix to determine the dominant eigenvector for
* @return the dominant eigenvector
*/
protected Matrix getDominantEigenVector(Matrix m) {
EigenvalueDecomposition eigendecomp;
double[] eigenvalues;
int index;
Matrix result;
eigendecomp = m.eig();
eigenvalues = eigendecomp.getRealEigenvalues();
index = Utils.maxIndex(eigenvalues);
result = columnAsVector(eigendecomp.getV(), index);
return result;
}
/**
* normalizes the given vector (inplace)
*
* @param v the vector to normalize
*/
protected void normalizeVector(Matrix v) {
double sum;
int i;
// determine length
sum = 0;
for (i = 0; i < v.getRowDimension(); i++)
sum += v.get(i, 0) * v.get(i, 0);
sum = StrictMath.sqrt(sum);
// normalize content
for (i = 0; i < v.getRowDimension(); i++)
v.set(i, 0, v.get(i, 0) / sum);
}
/**
* processes the instances using the PLS1 algorithm
*
* @param instances the data to process
* @return the modified data
* @throws Exception in case the processing goes wrong
*/
protected Instances processPLS1(Instances instances) throws Exception {
Matrix X, X_trans, x;
Matrix y;
Matrix W, w;
Matrix T, t, t_trans;
Matrix P, p, p_trans;
double b;
Matrix b_hat;
int i;
int j;
Matrix X_new;
Matrix tmp;
Instances result;
Instances tmpInst;
// initialization
if (!isFirstBatchDone()) {
// split up data
X = getX(instances);
y = getY(instances);
X_trans = X.transpose();
// init
W = new Matrix(instances.numAttributes() - 1, getNumComponents());
P = new Matrix(instances.numAttributes() - 1, getNumComponents());
T = new Matrix(instances.numInstances(), getNumComponents());
b_hat = new Matrix(getNumComponents(), 1);
for (j = 0; j < getNumComponents(); j++) {
// 1. step: wj
w = X_trans.times(y);
normalizeVector(w);
setVector(w, W, j);
// 2. step: tj
t = X.times(w);
t_trans = t.transpose();
setVector(t, T, j);
// 3. step: ^bj
b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0);
b_hat.set(j, 0, b);
// 4. step: pj
p = X_trans.times(t).times((double) 1 / t_trans.times(t).get(0, 0));
p_trans = p.transpose();
setVector(p, P, j);
// 5. step: Xj+1
X = X.minus(t.times(p_trans));
y = y.minus(t.times(b));
}
// W*(P^T*W)^-1
tmp = W.times(((P.transpose()).times(W)).inverse());
// X_new = X*W*(P^T*W)^-1
X_new = getX(instances).times(tmp);
// factor = W*(P^T*W)^-1 * b_hat
m_PLS1_RegVector = tmp.times(b_hat);
// save matrices
m_PLS1_P = P;
m_PLS1_W = W;
m_PLS1_b_hat = b_hat;
if (getPerformPrediction())
result = toInstances(getOutputFormat(), X_new, y);
else
result = toInstances(getOutputFormat(), X_new, getY(instances));
}
// prediction
else {
result = new Instances(getOutputFormat());
for (i = 0; i < instances.numInstances(); i++) {
// work on each instance
tmpInst = new Instances(instances, 0);
tmpInst.add((Instance) instances.instance(i).copy());
x = getX(tmpInst);
X = new Matrix(1, getNumComponents());
T = new Matrix(1, getNumComponents());
for (j = 0; j < getNumComponents(); j++) {
setVector(x, X, j);
// 1. step: tj = xj * wj
t = x.times(getVector(m_PLS1_W, j));
setVector(t, T, j);
// 2. step: xj+1 = xj - tj*pj^T (tj is 1x1 matrix!)
x = x.minus(getVector(m_PLS1_P, j).transpose().times(t.get(0, 0)));
}
if (getPerformPrediction())
tmpInst = toInstances(getOutputFormat(), T, T.times(m_PLS1_b_hat));
else
tmpInst = toInstances(getOutputFormat(), T, getY(tmpInst));
result.add(tmpInst.instance(0));
}
}
return result;
}
/**
* processes the instances using the SIMPLS algorithm
*
* @param instances the data to process
* @return the modified data
* @throws Exception in case the processing goes wrong
*/
protected Instances processSIMPLS(Instances instances) throws Exception {
Matrix A, A_trans;
Matrix M;
Matrix X, X_trans;
Matrix X_new;
Matrix Y, y;
Matrix C, c;
Matrix Q, q;
Matrix W, w;
Matrix P, p, p_trans;
Matrix v, v_trans;
Matrix T;
Instances result;
int h;
if (!isFirstBatchDone()) {
// init
X = getX(instances);
X_trans = X.transpose();
Y = getY(instances);
A = X_trans.times(Y);
M = X_trans.times(X);
C = Matrix.identity(instances.numAttributes() - 1, instances.numAttributes() - 1);
W = new Matrix(instances.numAttributes() - 1, getNumComponents());
P = new Matrix(instances.numAttributes() - 1, getNumComponents());
Q = new Matrix(1, getNumComponents());
for (h = 0; h < getNumComponents(); h++) {
// 1. qh as dominant EigenVector of Ah'*Ah
A_trans = A.transpose();
q = getDominantEigenVector(A_trans.times(A));
// 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column
w = A.times(q);
c = w.transpose().times(M).times(w);
w = w.times(1.0 / StrictMath.sqrt(c.get(0, 0)));
setVector(w, W, h);
// 3. ph=Mh*wh, store ph in P as column
p = M.times(w);
p_trans = p.transpose();
setVector(p, P, h);
// 4. qh=Ah'*wh, store qh in Q as column
q = A_trans.times(w);
setVector(q, Q, h);
// 5. vh=Ch*ph, vh=vh/||vh||
v = C.times(p);
normalizeVector(v);
v_trans = v.transpose();
// 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph'
C = C.minus(v.times(v_trans));
M = M.minus(p.times(p_trans));
// 7. Ah+1=ChAh (actually Ch+1)
A = C.times(A);
}
// finish
m_SIMPLS_W = W;
T = X.times(m_SIMPLS_W);
X_new = T;
m_SIMPLS_B = W.times(Q.transpose());
if (getPerformPrediction())
y = T.times(P.transpose()).times(m_SIMPLS_B);
else
y = getY(instances);
result = toInstances(getOutputFormat(), X_new, y);
}
else {
result = new Instances(getOutputFormat());
X = getX(instances);
X_new = X.times(m_SIMPLS_W);
if (getPerformPrediction())
y = X.times(m_SIMPLS_B);
else
y = getY(instances);
result = toInstances(getOutputFormat(), X_new, y);
}
return result;
}
/**
* Returns the Capabilities of this filter.
*
* @return the capabilities of this object
* @see Capabilities
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.DATE_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
result.enable(Capability.NUMERIC_CLASS);
result.enable(Capability.DATE_CLASS);
return result;
}
/**
* Processes the given data (may change the provided dataset) and returns
* the modified version. This method is called in batchFinished().
*
* @param instances the data to process
* @return the modified data
* @throws Exception in case the processing goes wrong
* @see #batchFinished()
*/
protected Instances process(Instances instances) throws Exception {
Instances result;
int i;
double clsValue;
double[] clsValues;
result = null;
// save original class values if no prediction is performed
if (!getPerformPrediction())
clsValues = instances.attributeToDoubleArray(instances.classIndex());
else
clsValues = null;
if (!isFirstBatchDone()) {
// init filters
if (m_ReplaceMissing)
m_Missing.setInputFormat(instances);
switch (m_Preprocessing) {
case PREPROCESSING_CENTER:
m_ClassMean = instances.meanOrMode(instances.classIndex());
m_ClassStdDev = 1;
m_Filter = new Center();
((Center) m_Filter).setIgnoreClass(true);
break;
case PREPROCESSING_STANDARDIZE:
m_ClassMean = instances.meanOrMode(instances.classIndex());
m_ClassStdDev = StrictMath.sqrt(instances.variance(instances.classIndex()));
m_Filter = new Standardize();
((Standardize) m_Filter).setIgnoreClass(true);
break;
default:
m_ClassMean = 0;
m_ClassStdDev = 1;
m_Filter = null;
}
if (m_Filter != null)
m_Filter.setInputFormat(instances);
}
// filter data
if (m_ReplaceMissing)
instances = Filter.useFilter(instances, m_Missing);
if (m_Filter != null)
instances = Filter.useFilter(instances, m_Filter);
switch (m_Algorithm) {
case ALGORITHM_SIMPLS:
result = processSIMPLS(instances);
break;
case ALGORITHM_PLS1:
result = processPLS1(instances);
break;
default:
throw new IllegalStateException(
"Algorithm type '" + m_Algorithm + "' is not recognized!");
}
// add the mean to the class again if predictions are to be performed,
// otherwise restore original class values
for (i = 0; i < result.numInstances(); i++) {
if (!getPerformPrediction()) {
result.instance(i).setClassValue(clsValues[i]);
}
else {
clsValue = result.instance(i).classValue();
result.instance(i).setClassValue(clsValue*m_ClassStdDev + m_ClassMean);
}
}
return result;
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 5987 $");
}
/**
* runs the filter with the given arguments.
*
* @param args the commandline arguments
*/
public static void main(String[] args) {
runFilter(new PLSFilter(), args);
}
}