/* $RCSfile$ * $Author$ * $Date$ * $Revision$ * * Copyright (C) 2004-2008 Rajarshi Guha <rajarshi.guha@gmail.com> * * Contact: cdk-devel@lists.sourceforge.net * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public License * as published by the Free Software Foundation; either version 2.1 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. */ package org.openscience.cdk.qsar.model.R; import org.openscience.cdk.qsar.model.QSARModelException; import java.util.HashMap; /** * A modeling class that provides a PLS regression model. * * When instantiated this class ensures that the R/Java interface has been * initialized. The response and independent variables can be specified at construction * time or via the <code>setParameters</code> method. * The actual fitting procedure is carried out by <code>build</code>. * <P><b>NOTE:</b> For this class to work, you must have the * <a href="http://cran.r-project.org/src/contrib/Descriptions/pls.pcr.html" target="_top">pls.pcr</a> * package installed in your R library. * <p> * When building the PLS model, parameters such as whether cross validation is to be used, the type of * PLS algorithm etc can be specified by making calls to <code>setParameters</code>. This method can also * be used to set a new X matrix for prediction. * The following table lists the parameters that can be set and their * expected types. More detailed information is available in the R documentation. * <center> * <table border=1 cellpadding=5> * <THEAD> * <tr> * <th>Name</th><th>Java Type</th><th>Default</th><th>Notes</th> * </tr> * </thead> * <tbody> * <tr> * <td>X</td><td>Double[][]</td><td>None</td><td>Variables should be in the columns, observations in the rows</td> * </tr> * <tr> * <td>Y</td><td>Double[][]</td><td>None</td><td>Length should be equal to the rows of X. Variables should be in the columns, observations in the rows</td> * </tr> * <tr> * <td>newX</td><td>Double[][]</td><td>None</td><td>A 2D array of values to make predictions for. Variables should be in the columns, observations in the rows</td> * </tr> * <tr> * <td>ncomp</td><td>Integer[]</td><td>{1,rank(X)}</td><td>This can be an array of length 1 or 2. If there is only one element * then only the specified number of latent variables will be assessed during modeling. If 2 values are specified * then the model will use N1 to N2 latent variables where N1 and N2 are the first and second elements respectively</td> * </tr> * <tr> * <td>method</td><td>String</td><td>"SIMPLS"</td><td>The type of PLS algorithm to use (can be SIMPLS or kernelPLS)</td> * </tr> * <tr> * <td>validation</td><td>String</td><td>"none"</td><td>Indicates whether cross validation should be used. To enable cross validation set this to "CV"</td> * </tr> * <tr> * <td>grpsize</td><td>Integer</td><td>0</td><td>The group size for the "CV" validation. By default this is ignored and <code>niter</code> is used to determine the value of this argument</td> * </tr> * <tr> * <td>niter</td><td>Integer</td><td>10</td><td>The number of iterations in the cross-validation. Note that if <code>grpsize</code> is set to a non-zero value then the value of <code>niter</code> will be calculated from the value of <code>grpsize</code></td> * </tr> * <tr> * <td>nlv</td><td>Integer</td><td>None</td><td>The number of latent variables to use during prediction. By default this does not need to be specified and will be obtained from the fitted model</td> * </tr> * </tbody> * </table> * </center> * <p> * In general the <code>getFit*</code> methods provide access to results from the fit and * <code>getPredict*</code> methods provide access to results from the prediction. In case validation is specified * then the results from the CV can be obtained via the <code>getValidation*</code> methods. * The values returned correspond to the various * values returned by the <a href="http://www.maths.lth.se/help/R/.R/library/pls.pcr/html/mvr.html" target="_top">pls</a> and * <a href="http://www.maths.lth.se/help/R/.R/library/pls.pcr/html/mvr.html" target="_top">predict.mvr</a> * functions in R. * <p> * See {@link RModel} for details regarding the R and SJava environment. * * @author Rajarshi Guha * @cdk.require r-project * @cdk.module qsar * @cdk.githash * * @cdk.keyword partial least squares * @cdk.keyword PLS * @cdk.keyword regression * @deprecated */ public class PLSRegressionModel extends RModel { private static int globalID = 0; private int currentID; private PLSRegressionModelFit modelfit = null; private PLSRegressionModelPredict modelpredict = null; private HashMap params = null; private int nvar = 0; private void setDefaults() { this.params.put("ncomp", new Boolean(false)); this.params.put("method", "SIMPLS"); this.params.put("validation", "none"); this.params.put("grpsize", Integer.valueOf(0)); this.params.put("niter", Integer.valueOf(10)); this.params.put("nlv", new Boolean(false)); } /** * Constructs a PLSRegressionModel object. * * The constructor simply instantiates the model ID. Dependent and independent variables * should be set via setParameters(). */ public PLSRegressionModel(){ super(); this.params = new HashMap(); this.currentID = PLSRegressionModel.globalID; PLSRegressionModel.globalID++; this.setModelName("cdkPLSRegressionModel"+this.currentID); this.setDefaults(); } /** * Constructs a PLSRegressionModel object. * * The constructor allows the user to specify the * dependent and independent variables. The length of the dependent variable * array should equal the number of rows of the independent variable matrix. If this * is not the case an exception will be thrown. * * @param xx An array of independent variables. The observations should be in the rows * and the variables should be in the columns * @param yy An array containing the dependent variable * @throws QSARModelException if the number of observations in x and y do not match */ public PLSRegressionModel(double[][] xx, double[] yy) throws QSARModelException{ super(); this.params = new HashMap(); this.currentID = PLSRegressionModel.globalID; PLSRegressionModel.globalID++; this.setModelName("cdkPLSRegressionModel"+this.currentID); this.setDefaults(); int nrow = yy.length; this.nvar = xx[0].length; if (nrow != xx.length) { throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix"); } Double[][] x = new Double[nrow][this.nvar]; Double[][] y = new Double[nrow][1]; for (int i = 0; i < nrow; i++) { y[i][1] = new Double(yy[i]); for (int j = 0; j < this.nvar; j++) x[i][j] = new Double(xx[i][j]); } params.put("X", x); params.put("Y", y); } /** * Constructs a PLSRegressionModel object. * * The constructor allows the user to specify the * dependent and independent variables. This constructor will accept a matrix * of Y values. * <p> * The length of the dependent variable * array should equal the number of rows of the independent variable matrix. If this * is not the case an exception will be thrown. * * @param xx An array of independent variables. The observations should be in the rows * and the variables should be in the columns * @param yy A 2D array containing the dependent variable * @throws QSARModelException if the number of observations in x and y do not match */ public PLSRegressionModel(double[][] xx, double[][] yy) throws QSARModelException{ super(); this.params = new HashMap(); this.currentID = PLSRegressionModel.globalID; PLSRegressionModel.globalID++; this.setModelName("cdkPLSRegressionModel"+this.currentID); this.setDefaults(); int nrow = yy.length; int ncoly = yy[0].length; this.nvar = xx[0].length; if (nrow != xx.length) { throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix"); } Double[][] x = new Double[nrow][this.nvar]; Double[][] y = new Double[nrow][ncoly]; //Double[] wts = new Double[nrow]; for (int i = 0; i < nrow; i++) { for (int j = 0; j < ncoly; j++) { y[i][j] = new Double(yy[i][j]); } } for (int i = 0; i < nrow; i++) { for (int j = 0; j < this.nvar; j++) x[i][j] = new Double(xx[i][j]); } params.put("X", x); params.put("Y", y); } protected void finalize() { revaluator.voidEval("rm("+this.getModelName()+",pos=1)"); } /** * Fits a PLS model. * * This method calls the R function to fit a PLS model * using the specified dependent and independent variables. If an error * occurs in the R session, an exception is thrown. */ public void build() throws QSARModelException { // lets do some checks in case stuff was set via setParameters() Double[][] x,y; x = (Double[][])this.params.get("X"); y = (Double[][])this.params.get("Y"); if (this.nvar == 0) this.nvar = x[0].length; else { if (y.length != x.length) { throw new QSARModelException("Number of observations does no match number of rows in the design matrix"); } } // lets build the model try { this.modelfit = (PLSRegressionModelFit)revaluator.call("buildPLS", new Object[]{ getModelName(), this.params }); } catch (Exception re) { throw new QSARModelException(re.toString()); } } /** * Uses a fitted model to predict the response for new observations. * * This function uses a previously fitted model to obtain predicted values * for a new set of observations. If the model has not been fitted prior to this * call an exception will be thrown. Use <code>setParameters</code> * to set the values of the independent variable for the new observations. */ public void predict() throws QSARModelException { if (this.modelfit == null) throw new QSARModelException("Before calling predict() you must fit the model using build()"); Double[][] newx = (Double[][])this.params.get(new String("newX")); if (newx[0].length != this.nvar) { throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting"); } try { this.modelpredict = (PLSRegressionModelPredict)revaluator.call("predictPLS", new Object[]{ getModelName(), this.params }); } catch (Exception re) { throw new QSARModelException(re.toString()); } } /** * Loads a PLSRegressionModel object from disk in to the current session. * * @param fileName The disk file containing the model * @throws QSARModelException if the model being loaded is not a PLS regression model * object */ public void loadModel(String fileName) throws QSARModelException { // should probably check that the filename does exist Object model = (Object)revaluator.call("loadModel", new Object[]{ (Object)fileName }); String modelName = (String)revaluator.call("loadModel.getName", new Object[] { (Object)fileName }); if (model.getClass().getName().equals("org.openscience.cdk.qsar.model.R.PLSRegressionModelFit")) { this.modelfit = (PLSRegressionModelFit)model; this.setModelName(modelName); } else throw new QSARModelException("The loaded model was not a PLSRegressionModel"); } /** * Loads an PLSRegressionModel object from a serialized string into the current session. * * @param serializedModel A String containing the serialized version of the model * @param modelName A String indicating the name of the model in the R session * @throws QSARModelException if the model being loaded is not a PLS regression model * object */ public void loadModel(String serializedModel, String modelName) throws QSARModelException { // should probably check that the fileName does exist Object model = (Object)revaluator.call("unserializeModel", new Object[]{ (Object)serializedModel, (Object)modelName }); String modelname = modelName; if (model.getClass().getName().equals("org.openscience.cdk.qsar.model.R.PLSRegressionModelFit")) { this.modelfit =(PLSRegressionModelFit)model; this.setModelName(modelname); } else throw new QSARModelException("The loaded model was not a PLSRegressionModel"); } /** * Sets parameters required for building a PLS model or using one for prediction. * * This function allows the caller to set the various parameters available * for the pls() and predict.mvr() R routines. See the R help pages for the details of the available * parameters. * * @param key A String containing the name of the parameter as described in the * R help pages * @param obj An Object containing the value of the parameter * @throws QSARModelException if the type of the supplied value does not match the * expected type */ public void setParameters(String key, Object obj) throws QSARModelException { // since we know the possible values of key we should check the coresponding // objects and throw errors if required. Note that this checking can't really check // for values (such as number of variables in the X matrix to build the model and the // X matrix to make new predictions) - these should be checked in functions that will // use these parameters. The main checking done here is for the class of obj and // some cases where the value of obj is not dependent on what is set before it if (key.equals("Y")) { if (!(obj instanceof Double[])) { throw new QSARModelException("The class of the 'Y' object must be Double[][]"); } } if (key.equals("X")) { if (!(obj instanceof Double[][])) { throw new QSARModelException("The class of the 'X' object must be Double[][]"); } } if (key.equals("method")) { if (!(obj instanceof String)) { throw new QSARModelException("The class of the 'method' object must be String"); } if (!(obj.equals("SIMPLS") || obj.equals("kernelPLS"))) { throw new QSARModelException("The value of method must be: SIMPLS or kernelPLS "); } } if (key.equals("validation")) { if (!(obj instanceof String)) { throw new QSARModelException("The class of the 'validation' object must be String"); } if (!(obj.equals("none") || obj.equals("CV"))) { throw new QSARModelException("The value of validation must be: none or CV"); } } if (key.equals("newX")) { if ( !(obj instanceof Double[][])) { throw new QSARModelException("The class of the 'newX' object must be Double[][]"); } } if (key.equals("grpsize")) { if (!(obj instanceof Integer)) { throw new QSARModelException("The class of the 'grpsize' object must be Integer"); } } if (key.equals("niter")) { if (!(obj instanceof Integer)) { throw new QSARModelException("The class of the 'niter' object must be Integer"); } } if (key.equals("nlv")) { if (!(obj instanceof Integer)) { throw new QSARModelException("The class of the 'nlv' object must be Integer"); } } if (key.equals("ncomp")) { if (!(obj instanceof Integer[])) { throw new QSARModelException("The class of the 'ncomp' object must be Integer[]"); } Integer[] tmp = (Integer[])obj; if (tmp.length != 1 && tmp.length != 2) { throw new QSARModelException("The 'ncomp' array can have a length of 1 or 2. See documentation"); } } this.params.put(key,obj); } /* interface to fit object */ /** * The method used to build the PLS model. * * @return String containing 'SIMPLS' or 'kernelPLS' */ public String getFitMethod() { return(this.modelfit.getMethod()); } /** * Returns the fit NComp value. * * @return An array of integers indicating the number of components * (latent variables) */ public int[] getFitNComp() { return(this.modelfit.getNComp()); } /** * Gets the coefficents. * * The return value is a 3D array. The first dimension corresponds * to the specific number of LV's (1 or 2 or 3 and so on). The second * dimension corresponds to the independent variables and the third * dimension corresponds to the Y variables. * * @return double[][][] containing the coefficients */ public double[][][] getFitB() { return(this.modelfit.getB()); } /** * Get the Root Mean Square (RMS) error for the fit. * * @return A 2-dimensional array of RMS errors. */ public double[][] getFitRMS() { return(this.modelfit.getTrainingRMS()); } /** * Get the predicted Y's. * * Each set of latent variables is used to make predictions for all the * Y variables. * * @return A 3-dimensional array of doubles. The first dimension corresponds * to the set of latent variables and the remaining two correspond to the * Y's themselves. */ public double[][][] getFitYPred() { return(this.modelfit.getTrainingYPred()); } /** * Get the X loadings. * * @return A 2-dimensional array of doubles containing the X loadings */ public double[][] getFitXLoading() { return(this.modelfit.getXLoading()); } /** * Get the Y loadings. * * @return A 2-dimensional array of doubles containing the Y loadings */ public double[][] getFitYLoading() { return(this.modelfit.getYLoading()); } /** * Get the X scores. * * @return A 2-dimensional array of doubles containing the X scores */ public double[][] getFitXScores() { return(this.modelfit.getXScores()); } /** * Get the Y scores. * * @return A 2-dimensional array of doubles containing the Y scores */ public double[][] getFitYScores() { return(this.modelfit.getYScores()); } /** * Indicates whether CV was used to build the model. * * @return A boolean indicating whether CV was used */ public boolean getFitWasValidated() { return(this.modelfit.wasValidated()); } /** * The number of iterations used during CV. * * @return An int value indicating the number of iterations in CV */ public int getValidationIter() { return(this.modelfit.getValidationIter()); } /** * The number of latent variables suggested by CV. * * @return An int value indicating the number of LV's */ public int getValidationLV() { return(this.modelfit.getValidationLV()); } /** * Get the R^2 value for validation. * * @return A 2-dimensional array of doubles */ public double[][] getValidationR2() { return(this.modelfit.getValidationR2()); } /** * Get the RMS value for validation. * * @return A 2-dimensional array of doubles */ public double[][] getValidationRMS() { return(this.modelfit.getValidationRMS()); } /** * Get the standard deviation of the RMS errrors for validation. * * @return A 2-dimensional array of doubles */ public double[][] getValidationRMSsd() { return(this.modelfit.getValidationRMSSD()); } /** * Get the predicted Y values from validation. * * @return A 2-dimensional array of doubles */ public double[][][] getValidationYPred() { return(this.modelfit.getValidationYPred()); } /* interface to predict object */ /** * Returns the predicted values for the prediction set. * * This function only returns meaningful results if the <code>predict</code> * method of this class has been called. * * @return A double[][] containing the predicted values */ public double[][] getPredictPredicted() { return(this.modelpredict.getPredictions()); } }