/* $RCSfile$
* $Author$
* $Date$
* $Revision$
*
* Copyright (C) 2004-2008 Rajarshi Guha <rajarshi.guha@gmail.com>
*
* Contact: cdk-devel@lists.sourceforge.net
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public License
* as published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*/
package org.openscience.cdk.qsar.model.R2;
import java.io.File;
import java.util.HashMap;
import org.openscience.cdk.qsar.model.QSARModelException;
import org.openscience.cdk.tools.ILoggingTool;
import org.openscience.cdk.tools.LoggingToolFactory;
import org.rosuda.JRI.REXP;
import org.rosuda.JRI.RList;
/**
* A modeling class that provides a linear least squares regression model.
* <p/>
* When instantiated this class ensures that the R/Java interface has been
* initialized. The response and independent variables can be specified at construction
* time or via the <code>setParameters</code> method. The actual fitting procedure is carried out by <code>build</code> after which
* the model may be used to make predictions.
* <p/>
* Currently, the design of the class is quite sparse as it does not allow subsetting,
* variable names, setting of contrasts and so on.
* It is also assumed that the values of all the variables are defined (i.e., not such that
* they are <a href="http://stat.ethz.ch/R-manual/R-patched/library/base/html/NA.html">NA</a>
* in an R session).
* The use of
* this class is shown in the following code snippet
* <pre>
* double[][] x;
* double[] y;
* try {
* LinearRegressionModel lrm = new LinearRegressionModel(x,y);
* lrm.build();
* lrm.setParameters("newdata", newx);
* lrm.setParameters("interval", "confidence");
* lrm.predict();
* } catch (QSARModelException qme) {
* System.out.println(qme.toString());
* }
* double[] fitted = lrm.getFittedValues()
* double[] predicted = lrm.getModelPredict().asList.at("fit").asDoubleArray();
* </pre>
* Note that when making predictions, the new X matrix and interval type can be set by calls
* to setParameters(). In general, the arguments for lm() and predict.lm() can be set via
* calls to setParameters(). The following table lists the parameters that can be set and their
* expected types. More detailed informationis available in the R documentation.
* <center>
* <table border=1 cellpadding=5>
* <THEAD>
* <tr>
* <th>Name</th><th>Java Type</th><th>Notes</th>
* </tr>
* </thead>
* <tbody>
* <tr>
* <td>x</td><td>Double[][]</td><td></td>
* </tr>
* <tr>
* <td>y</td><td>Double[]</td><td>Length should be equal to the rows of x</td>
* </tr>
* <tr>
* <td>weights</td><td>Double[]</td><td>Length should be equal to rows of x</td>
* </tr>
* <tr>
* <td>newdata</td><td>Double[][]</td><td>Number of columns should be the same as in x</td>
* </tr>
* <tr>
* <td>interval</td><td>String</td><td>Can be 'confidence' or 'predicton'</td>
* </tr>
* </tbody>
* </table>
* </center>
* In general the <code>getFit*</code> methods provide access to results from the fit
* and <code>getPredict*</code> methods provide access to results from the prediction (i.e.,
* prediction using the model on new data). The values returned correspond to the various
* values returned by the <a href="http://stat.ethz.ch/R-manual/R-patched/library/stats/html/lm.html">lm</a>
* and <a href="http://stat.ethz.ch/R-manual/R-patched/library/stats/html/predict.lm.html">predict.lm</a>
* functions in R.
* <p/>
* See {@link RModel} for details regarding the R and rJava environment.
*
* @author Rajarshi Guha
* @cdk.require r-project
* @cdk.module qsar
* @cdk.githash
* @cdk.keyword linear regression
* @cdk.keyword R
*/
public class LinearRegressionModel extends org.openscience.cdk.qsar.model.R2.RModel {
private static int globalID = 0;
private int nvar = 0;
private RList modelPredict = null;
private static ILoggingTool logger =
LoggingToolFactory.createLoggingTool(LinearRegressionModel.class);
/**
* Constructs a LinearRegressionModel object.
* <p/>
* The constructor simply instantiates the model ID. Dependent and independent variables
* should be set via setParameters().
* <p/>
* An important feature of the current implementation is that <i>all</i> the
* independent variables are used during the fit. Furthermore no subsetting is possible.
* As a result when setting these via setParameters() the caller should specify only
* the variables and observations that will be used for the fit.
*/
public LinearRegressionModel() throws QSARModelException {
super();
params = new HashMap();
int currentID = LinearRegressionModel.globalID;
org.openscience.cdk.qsar.model.R2.LinearRegressionModel.globalID++;
this.setModelName("cdkLMModel" + currentID);
}
/**
* Constructs a LinearRegressionModel object.
* <p/>
* The constructor allows the user to specify the
* dependent and independent variables. The length of the dependent variable
* array should equal the number of rows of the independent variable matrix. If this
* is not the case an exception will be thrown.
* <p/>
* An important feature of the current implementation is that <i>all</i> the
* independent variables are used during the fit. Furthermore no subsetting is possible.
* As a result when creating an instance of this object the caller should specify only
* the variables and observations that will be used for the fit.
*
* @param xx An array of independent variables. The observations should be in the rows
* and the variables should be in the columns
* @param yy an array containing the dependent variable
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the number of observations in x and y do not match
*/
public LinearRegressionModel(double[][] xx, double[] yy) throws QSARModelException {
super();
params = new HashMap();
int currentID = LinearRegressionModel.globalID;
LinearRegressionModel.globalID++;
this.setModelName("cdkLMModel" + currentID);
int nrow = yy.length;
this.nvar = xx[0].length;
if (nrow != xx.length) {
throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix");
}
Double[][] x = new Double[nrow][this.nvar];
Double[] y = new Double[nrow];
Double[] weights = new Double[nrow];
for (int i = 0; i < nrow; i++) {
y[i] = new Double(yy[i]);
weights[i] = new Double(1.0);
}
for (int i = 0; i < nrow; i++) {
for (int j = 0; j < this.nvar; j++)
x[i][j] = new Double(xx[i][j]);
}
params.put("x", x);
params.put("y", y);
params.put("weights", weights);
}
/**
* Constructs a LinearRegressionModel object.
* <p/>
* The constructor allows the user to specify the
* dependent and independent variables as well as weightings for
* the observations.
* <p/>
* The length of the dependent variable
* array should equal the number of rows of the independent variable matrix. If this
* is not the case an exception will be thrown.
* <p/>
* An important feature of the current implementation is that <i>all</i> the
* independent variables are used during the fit. Furthermore no subsetting is possible.
* As a result when creating an instance of this object the caller should specify only
* the variables and observations that will be used for the fit.
*
* @param xx An array of independent variables. The observations should be in the rows
* and the variables should be in the columns
* @param yy an array containing the dependent variable
* @param weights Specifies the weights for each observation. Unit weights are equivilant
* to OLS
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the number of observations in x and y do not match
*/
public LinearRegressionModel(double[][] xx, double[] yy, double[] weights) throws QSARModelException {
super();
params = new HashMap();
int currentID = LinearRegressionModel.globalID;
org.openscience.cdk.qsar.model.R2.LinearRegressionModel.globalID++;
this.setModelName("cdkLMModel" + currentID);
int nrow = yy.length;
this.nvar = xx[0].length;
if (nrow != xx.length) {
throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix");
}
if (nrow != weights.length) {
throw new QSARModelException("The length of the weight vector does not match the number of rows of the design matrix");
}
Double[][] x = new Double[nrow][this.nvar];
Double[] y = new Double[nrow];
Double[] wts = new Double[nrow];
for (int i = 0; i < nrow; i++) {
y[i] = new Double(yy[i]);
wts[i] = new Double(weights[i]);
}
for (int i = 0; i < nrow; i++) {
for (int j = 0; j < this.nvar; j++)
x[i][j] = new Double(xx[i][j]);
}
params.put("x", x);
params.put("y", y);
params.put("weights", wts);
}
/**
* Fits a linear regression model.
* <p/>
* This method calls the R function to fit a linear regression model
* to the specified dependent and independent variables. If an error
* occurs in the R session, an exception is thrown.
* <p/>
* Note that, this method should be called prior to calling the various get
* methods to obtain information regarding the fit.
*/
public void build() throws QSARModelException {
// lets do some checks in case stuff was set via setParameters()
Double[][] x;
Double[] y, weights;
x = (Double[][]) this.params.get("x");
y = (Double[]) this.params.get("y");
weights = (Double[]) this.params.get("weights");
if (this.nvar == 0) this.nvar = x[0].length;
else {
if (y.length != x.length) {
throw new QSARModelException("Number of observations does no match number of rows in the design matrix");
}
if (weights.length != y.length) {
throw new QSARModelException("The weight vector must have the same length as the number of observations");
}
}
// lets build the model
String paramVarName = loadParametersIntoRSession();
String cmd = "buildLM(\"" + getModelName() + "\", " + paramVarName + ")";
REXP ret = rengine.eval(cmd);
if (ret == null) {
logger.debug("Error in buildLM");
throw new QSARModelException("Error in buildLM");
}
// remove the parameter list
rengine.eval("rm(" + paramVarName + ")");
// save the model object on the Java side
modelObject = ret.asList();
}
/**
* Sets parameters required for building a linear model or using one for prediction.
* <p/>
* This function allows the caller to set the various parameters available
* for the lm() and predict.lm() R routines. See the R help pages for the details of the available
* parameters.
*
* @param key A String containing the name of the parameter as described in the
* R help pages
* @param obj An Object containing the value of the parameter
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the type of the supplied value does not match the
* expected type
*/
public void setParameters(String key, Object obj) throws QSARModelException {
// since we know the possible values of key we should check the coresponding
// objects and throw errors if required. Note that this checking can't really check
// for values (such as number of variables in the X matrix to build the model and the
// X matrix to make new predictions) - these should be checked in functions that will
// use these parameters. The main checking done here is for the class of obj and
// some cases where the value of obj is not dependent on what is set before it
if (key.equals("y")) {
if (!(obj instanceof Double[])) {
throw new QSARModelException("The class of the 'y' object must be Double[]");
}
}
if (key.equals("x")) {
if (!(obj instanceof Double[][])) {
throw new QSARModelException("The class of the 'x' object must be Double[][]");
}
}
if (key.equals("weights")) {
if (!(obj instanceof Double[])) {
throw new QSARModelException("The class of the 'weights' object must be Double[]");
}
}
if (key.equals("interval")) {
if (!(obj instanceof String)) {
throw new QSARModelException("The class of the 'interval' object must be String");
}
if (!(obj.equals("confidence") || obj.equals("prediction"))) {
throw new QSARModelException("The type of interval must be: prediction or confidence");
}
}
if (key.equals("newdata")) {
if (!(obj instanceof Double[][])) {
throw new QSARModelException("The class of the 'newdata' object must be Double[][]");
}
}
this.params.put(key, obj);
}
/**
* Uses a fitted model to predict the response for new observations.
* <p/>
* This function uses a previously fitted model to obtain predicted values
* for a new set of observations. If the model has not been fitted prior to this
* call an exception will be thrown. Use <code>setParameters</code>
* to set the values of the independent variable for the new observations and the
* interval type.
*
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the model has not been built prior to a call
* to this method. Also if the number of independent variables specified for prediction
* is not the same as specified during model building
*/
public void predict() throws QSARModelException {
if (modelObject == null)
throw new QSARModelException("Before calling predict() you must fit the model using build()");
Double[][] newx = (Double[][]) params.get("newdata");
if (newx[0].length != nvar) {
throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting");
}
String pn = loadParametersIntoRSession();
REXP ret = rengine.eval("predictLM(\"" + getModelName() + "\", " + pn + ")");
if (ret == null) throw new QSARModelException("Error occured in prediction");
// remove the parameter list
rengine.eval("rm(" + pn + ")");
modelPredict = ret.asList();
}
/**
* Get the R object obtained from <code>predict.lm()</code>.
*
* @return The result of the prediction. Contains a number of fields corresponding to
* predicted values, SE and other items depending on the parameters that we set.
* Note that the call to <code>predict.lm()</code> is performde with <code>se.fit = TRUE</code>
*/
public RList getModelPredict() {
return modelPredict;
}
/**
* Returns an <code>RList</code> object summarizing the linear regression model.
* <p/>
* The return object can be queried via the <code>RList</code> methods to extract the
* required components.
*
* @return A summary for the linear regression model
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the model has not been built prior to a call
* to this method
*/
public RList summary() throws QSARModelException {
if (modelObject == null)
throw new QSARModelException("Before calling summary() you must fit the model using build()");
REXP ret = rengine.eval("summary(" + getModelName() + ")");
if (ret == null) {
logger.debug("Error in summary()");
throw new QSARModelException("Error in summary()");
}
return ret.asList();
}
/**
* Loads an LinearRegressionModel object from disk in to the current session.
*
* @param fileName The disk file containing the model
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the model being loaded is not a linear regression model
* object or the file does not exist
*/
public void loadModel(String fileName) throws QSARModelException {
File f = new File(fileName);
if (!f.exists()) throw new QSARModelException(fileName + " does not exist");
rengine.assign("tmpFileName", fileName);
REXP ret = rengine.eval("loadModel(tmpFileName)");
if (ret == null) throw new QSARModelException("Model could not be loaded");
String name = ret.asList().at("name").asString();
if (!isOfClass(name, "lm")) {
removeObject(name);
throw new QSARModelException("Loaded object was not of class \'lm\'");
}
modelObject = ret.asList().at("model").asList();
setModelName(name);
nvar = getCoefficients().length - 1; // since the intercept is also returned
}
/**
* Loads an LinearRegressionModel object from a serialized string into the current session.
*
* @param serializedModel A String containing the serialized version of the model
* @param modelName A String indicating the name of the model in the R session
* @throws org.openscience.cdk.qsar.model.QSARModelException
* if the model being loaded is not a linear regression model
* object
*/
public void loadModel(String serializedModel, String modelName) throws QSARModelException {
rengine.assign("tmpSerializedModel", serializedModel);
rengine.assign("tmpModelName", modelName);
REXP ret = rengine.eval("unserializeModel(tmpSerializedModel, tmpModelName)");
if (ret == null) throw new QSARModelException("Model could not be unserialized");
String name = ret.asList().at("name").asString();
if (!isOfClass(name, "lm")) {
removeObject(name);
throw new QSARModelException("Loaded object was not of class \'lm\'");
}
modelObject = ret.asList().at("model").asList();
setModelName(name);
nvar = getCoefficients().length - 1; // as the intercept is also returned
}
// Autogenerated code: assumes that 'modelObject' is
// a RList object
/**
* Gets the <code>assign</code> field of an <code>'lm'</code> object.
*
* @return The value of the assign field
*/
public int[] getAssign() {
return modelObject.at("assign").asIntArray();
}
/**
* Gets the <code>coefficients</code> field of an <code>'lm'</code> object.
*
* @return The value of the coefficients field
*/
public double[] getCoefficients() {
return modelObject.at("coefficients").asDoubleArray();
}
/**
* Gets the <code>df.residual</code> field of an <code>'lm'</code> object.
*
* @return The value of the df.residual field
*/
public int getDfResidual() {
return modelObject.at("df.residual").asInt();
}
/**
* Gets the <code>effects</code> field of an <code>'lm'</code> object.
*
* @return The value of the effects field
*/
public double[] getEffects() {
return modelObject.at("effects").asDoubleArray();
}
/**
* Gets the <code>fitted.values</code> field of an <code>'lm'</code> object.
*
* @return The value of the fitted.values field
*/
public double[] getFittedValues() {
return modelObject.at("fitted.values").asDoubleArray();
}
/**
* Gets the <code>model</code> field of an <code>'lm'</code> object.
*
* @return The value of the model field
*/
public RList getModel() {
return modelObject.at("model").asList();
}
/**
* Gets the <code>qr</code> field of an <code>'lm'</code> object.
*
* @return The value of the qr field
*/
public RList getQr() {
return modelObject.at("qr").asList();
}
/**
* Gets the <code>rank</code> field of an <code>'lm'</code> object.
*
* @return The value of the rank field
*/
public int getRank() {
return modelObject.at("rank").asInt();
}
/**
* Gets the <code>residuals</code> field of an <code>'lm'</code> object.
*
* @return The value of the residuals field
*/
public double[] getResiduals() {
return modelObject.at("residuals").asDoubleArray();
}
/**
* Gets the <code>xlevels</code> field of an <code>'lm'</code> object.
*
* @return The value of the xlevels field
*/
public RList getXlevels() {
return modelObject.at("xlevels").asList();
}
}