/*
*
* YAQP - Yet Another QSAR Project:
* Machine Learning algorithms designed for the prediction of toxicological
* features of chemical compounds become available on the Web. Yaqp is developed
* under OpenTox (http://opentox.org) which is an FP7-funded EU research project.
* This project was developed at the Automatic Control Lab in the Chemical Engineering
* School of National Technical University of Athens. Please read README for more
* information.
*
* Copyright (C) 2009-2010 Pantelis Sopasakis & Charalampos Chomenides
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Contact:
* Pantelis Sopasakis
* chvng@mail.ntua.gr
* Address: Iroon Politechniou St. 9, Zografou, Athens Greece
* tel. +30 210 7723236
*/
package org.opentox.qsar.processors.trainers.regression;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.opentox.config.Configuration;
import org.opentox.config.ServerFolders;
import org.opentox.core.exceptions.Cause;
import org.opentox.core.exceptions.YaqpException;
import org.opentox.io.exceptions.YaqpIOException;
import org.opentox.core.processors.Pipeline;
import org.opentox.db.exceptions.DbException;
import org.opentox.db.handlers.WriterHandler;
import org.opentox.db.util.TheDbConnector;
import org.opentox.io.processors.InputProcessor;
import org.opentox.io.util.ServerList;
import org.opentox.ontology.components.Feature;
import org.opentox.ontology.components.QSARModel;
import org.opentox.ontology.components.QSARModel.ModelStatus;
import org.opentox.ontology.components.User;
import org.opentox.ontology.data.DatasetBuilder;
import org.opentox.ontology.processors.InstancesProcessor;
import org.opentox.ontology.util.AlgorithmParameter;
import org.opentox.ontology.util.YaqpAlgorithms;
import org.opentox.qsar.exceptions.QSARException;
import org.opentox.qsar.processors.filters.AttributeCleanup;
import org.opentox.qsar.processors.filters.AttributeCleanup.ATTRIBUTE_TYPE;
import org.opentox.qsar.processors.filters.InstancesFilter;
import org.opentox.qsar.processors.filters.SimpleMVHFilter;
import org.opentox.qsar.processors.trainers.WekaTrainer;
import org.opentox.util.logging.YaqpLogger;
import org.opentox.util.logging.levels.Trace;
import org.opentox.www.rest.components.YaqpForm;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.core.Instances;
import weka.core.converters.ArffSaver;
/**
 * Trainer for Multiple Linear Regression (MLR) models, backed by Weka's
 * {@link LinearRegression}.
 *
 * <p>This MLR trainer accepts the training data as {@link Instances} and produces a
 * serialized Weka model file which is stored under {@link ServerFolders#models_weka}.
 * In addition, a PMML description of the model is generated and stored under
 * {@link ServerFolders#models_pmml}. Both files are named after the trainer's
 * <code>uuid</code>.</p>
 *
 * <p>The fields <code>uuid</code> and <code>datasetUri</code> used below are not declared
 * in this class; presumably they are inherited from {@link WekaRegressor}.</p>
 *
 * @author Pantelis Sopasakis
 * @author Charalampos Chomenides
 */
@SuppressWarnings({"unchecked"})
public class MLRTrainer extends WekaRegressor {

    /** Opening <code>PMML</code> element and header written right after the XML declaration. */
    private static final String PMML_INTRO =
            "<PMML version=\"3.2\" "
            + " xmlns=\"http://www.dmg.org/PMML-3_2\" "
            + " xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\">\n"
            + " <Header copyright=\"Copyleft (c) OpenTox - An Open Source "
            + "Predictive Toxicology Framework, http://www.opentox.org, 2009\" />\n";

    /**
     * <p>Construct a new MLR trainer which is initialized with a set of parameters as
     * posted by the client to the resource (instance of {@link YaqpForm }). Some
     * checks are done regarding the consistency of the parameters. The mandatory parameters
     * one has to specify are the <code>prediction_feature</code> and the <code>
     * dataset_uri</code> which should be posted by the client to the training resource.
     * All the work is delegated to the superclass constructor.
     * </p>
     * @param form
     *      Set of parameters as an instance of {@link YaqpForm }
     * @throws QSARException
     *      In case there are some inconsistencies with the posted parameters.
     * <p>
     * <table>
     * <thead>
     * <tr>
     * <td><b>Code</b></td><td><b>Explanation</b></td>
     * </tr>
     * </thead>
     * <tbody>
     * <tr>
     * <td>XQReg200</td><td>The <code>prediction_feature</code> you provided is not a valid URI</td>
     * </tr>
     * <tr>
     * <td>XQReg201</td><td>The <code>dataset_uri</code> you provided is not a valid URI</td>
     * </tr>
     * </tbody>
     * </table>
     * </p>
     * @throws NullPointerException
     *      In case some mandatory parameters are not specified or
     *      are set to null (For example if one does not specify the
     *      <code>prediction_feature</code>).
     */
    public MLRTrainer(YaqpForm form) throws QSARException {
        super(form);
    }

    /**
     * <p>This is an auxiliary constructor for MLR Trainer, mainly for testing purposes
     * and is equivalent to {@link MLRTrainer#MLRTrainer(org.opentox.www.rest.components.YaqpForm) }.
     * Construct a new MLR trainer which is initialized with a set of parameters. Some
     * checks are done regarding the consistency of the parameters. The mandatory parameters
     * one has to specify are the <code>prediction_feature</code> and the <code>
     * dataset_uri</code> which should be posted by the client to the training resource.
     * All the work is delegated to the superclass constructor.
     * </p>
     * @param parameters
     *      Set of parameters as a Map&lt;String,{@link AlgorithmParameter }&gt;.
     * @throws QSARException
     *      In case there are some inconsistencies with the posted parameters.
     * <p>
     * <table>
     * <thead>
     * <tr>
     * <td><b>Code</b></td><td><b>Explanation</b></td>
     * </tr>
     * </thead>
     * <tbody>
     * <tr>
     * <td>XQReg200</td><td>The <code>prediction_feature</code> you provided is not a valid URI</td>
     * </tr>
     * <tr>
     * <td>XQReg201</td><td>The <code>dataset_uri</code> you provided is not a valid URI</td>
     * </tr>
     * </tbody>
     * </table>
     * </p>
     * @throws NullPointerException
     *      In case some mandatory parameters are not specified or
     *      are set to null (For example if one does not specify the
     *      <code>prediction_feature</code>).
     */
    public MLRTrainer(Map<String, AlgorithmParameter> parameters) throws QSARException {
        super(parameters);
    }

    /**
     *
     * Construct a new MLRTrainer object. No parameters are specified for the trainer.
     * A new <code>UUID</code> is chosen.
     */
    public MLRTrainer() {
        super();
    }

    /**
     * Trains the MLR model given an Instances object with the training data. The prediction
     * feature (class attribute) is specified in the constructor of the class.
     *
     * <p>The method performs the following steps, in this order:</p>
     * <ol>
     * <li>dumps <code>data</code> into a temporary ARFF file under {@link ServerFolders#temp};</li>
     * <li>builds a {@link LinearRegression} directly on <code>data</code>;</li>
     * <li>writes the PMML description of that regression;</li>
     * <li>runs {@link Evaluation#evaluateModel} on the temporary ARFF file, asking Weka to
     *     store the serialized model under {@link ServerFolders#models_weka};</li>
     * <li>assembles the {@link QSARModel} describing the result.</li>
     * </ol>
     * <p>The temporary ARFF file is removed before the method returns, whether training
     * succeeded or not.</p>
     *
     * @param data The training data as <code>weka.core.Instances</code> object.
     * @return The QSARModel corresponding to the trained model, with status
     *      {@link ModelStatus#UNDER_DEVELOPMENT} and an empty parameter map.
     * @throws QSARException In case the model cannot be trained
     * <p>
     * <table>
     * <thead>
     * <tr>
     * <td><b>Code</b></td><td><b>Explanation</b></td>
     * </tr>
     * </thead>
     * <tbody>
     * <tr>
     * <td>XQReg1</td><td>Could not train the MLR model (illegal options or the classifier
     * could not be built)</td>
     * </tr>
     * <tr>
     * <td>XQReg2</td><td>Could not generate PMML representation for the model</td>
     * </tr>
     * <tr>
     * <td>XQReg350</td><td>Weka's evaluation run, which also stores the serialized model, failed</td>
     * </tr>
     * <tr>
     * <td>XQReg202</td><td>The prediction feature you provided is not a valid numeric attribute
     * of the dataset (not raised by the code of this method; presumably checked by the
     * superclass before training)</td>
     * </tr>
     * </tbody>
     * </table>
     * </p>
     * @throws RuntimeException
     *      In case the training data cannot be written to the temporary ARFF file.
     * @throws NullPointerException
     *      Possibly, in case the provided training data is null
     *      (NOTE(review): depends on how Weka reacts to null data; confirm).
     */
    public QSARModel train(Instances data) throws QSARException {
        // GET A UUID AND DEFINE THE TEMPORARY FILE WHERE THE TRAINING DATA
        // ARE STORED IN ARFF FORMAT PRIOR TO TRAINING.
        final String rand = java.util.UUID.randomUUID().toString();
        final String temporaryFilePath = ServerFolders.temp + "/" + rand + ".arff";
        final File tempFile = new File(temporaryFilePath);

        try {
            // SAVE THE DATA IN THE TEMPORARY FILE
            saveAsArff(data, tempFile);

            // BUILD THE REGRESSION WHOSE COEFFICIENTS END UP IN THE PMML FILE
            final LinearRegression linreg = new LinearRegression();
            // Per Weka's LinearRegression options: "-S 1" disables attribute selection and
            // "-C" turns off the elimination of collinear attributes.
            final String[] linRegOptions = {"-S", "1", "-C"};
            try {
                linreg.setOptions(linRegOptions);
                linreg.buildClassifier(data);
            } catch (final Exception ex) {// illegal options or could not build the classifier!
                final String message = "MLR Model could not be trained";
                YaqpLogger.LOG.log(new Trace(getClass(), message + " :: " + ex));
                throw new QSARException(Cause.XQReg1, message, ex);
            }

            try {
                generatePMML(linreg, data);
            } catch (final YaqpIOException ex) {
                final String message = "Could not generate PMML representation for MLR model :: " + ex;
                throw new QSARException(Cause.XQReg2, message, ex);
            }

            // PERFORM THE TRAINING
            final String[] generalOptions = {
                // Weka expects a 1-based class index on the command line
                "-c", Integer.toString(data.classIndex() + 1),
                "-t", temporaryFilePath,
                /// Save the model in the following directory
                "-d", ServerFolders.models_weka + "/" + uuid};
            try {
                Evaluation.evaluateModel(linreg, generalOptions);
            } catch (final Exception ex) {
                throw new QSARException(Cause.XQReg350, "Unexpected condition while trying to train "
                        + "an MLR model. Possible explanation : {" + ex.getMessage() + "}", ex);
            }

            return describeModel(data);
        } finally {
            // The ARFF dump is only needed while this method runs; never leave it behind.
            tempFile.delete();
        }
    }

    /**
     * Writes the given data set to the given file in ARFF format.
     *
     * @param data The data to be stored.
     * @param target The file the data is written to.
     * @throws RuntimeException In case the file cannot be opened or written.
     */
    private static void saveAsArff(final Instances data, final File target) {
        FileOutputStream out = null;
        try {
            final ArffSaver dataSaver = new ArffSaver();
            dataSaver.setInstances(data);
            out = new FileOutputStream(target);
            dataSaver.setDestination(out);
            dataSaver.writeBatch();
        } catch (final IOException ex) {
            throw new RuntimeException("Unexpected condition while trying to save the "
                    + "dataset in a temporary ARFF file", ex);
        } finally {
            // Closing an already closed FileOutputStream has no effect, so this is safe
            // even if the saver closed the stream itself.
            closeQuietly(out);
        }
    }

    /**
     * Assembles the {@link QSARModel} describing a model trained on the given data: every
     * attribute except the class attribute becomes an independent feature, and the class
     * attribute is used both as the dependent and as the predicted feature.
     *
     * @param data The training data; its class attribute has to be set.
     * @return The description of the trained model.
     */
    private QSARModel describeModel(final Instances data) {
        final int classIndex = data.classIndex();
        final ArrayList<Feature> independentFeatures = new ArrayList<Feature>();
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != classIndex) {
                independentFeatures.add(new Feature(data.attribute(i).name()));
            }
        }
        final Feature dependentFeature = new Feature(data.classAttribute().name());
        final Feature predictedFeature = dependentFeature;
        final QSARModel model = new QSARModel(
                uuid.toString(), predictedFeature, dependentFeature,
                independentFeatures, YaqpAlgorithms.MLR,
                new User(), null, datasetUri, ModelStatus.UNDER_DEVELOPMENT);
        model.setParams(new HashMap<String, AlgorithmParameter>());
        return model;
    }

    /**
     * Generates the PMML representation of the model and stores it on the hard
     * disk, in the file <code>ServerFolders.models_pmml + "/" + uuid</code>, encoded as UTF-8.
     *
     * <p>The target of the model is taken from the class attribute of <code>data</code>,
     * so that the <code>target</code> element, the <code>MiningSchema</code> and the
     * <code>RegressionTable</code> always agree with each other. Attribute names and
     * the dataset URI are XML-escaped before being written.</p>
     *
     * @param wekaModel The trained regression whose coefficients are exported. The array
     *      returned by <code>coefficients()</code> is read per attribute index, with the
     *      intercept taken from its last element.
     * @param data The training data; only its attribute names and class index are read.
     * @throws YaqpIOException With cause <code>XQReg3</code> in case the PMML file
     *      cannot be written.
     * TODO: build the XML using some XML editor
     */
    // <editor-fold defaultstate="collapsed" desc="PMML generation routine!">
    private void generatePMML(final LinearRegression wekaModel, final Instances data) throws YaqpIOException {
        final double[] coefficients = wekaModel.coefficients();
        final int numAttributes = data.numAttributes();
        final int classIndex = data.classIndex();
        final String targetName = escapeXml(data.classAttribute().name());

        final StringBuilder pmml = new StringBuilder();
        pmml.append("<?xml version=\"1.0\" ?>");
        pmml.append(PMML_INTRO);
        pmml.append("<Model ID=\"" + uuid.toString() + "\" Name=\"MLR Model\">\n");
        pmml.append("<AlgorithmID href=\"" + Configuration.BASE_URI + "/algorithm/mlr\"/>\n");
        pmml.append("<DatasetID href=\"" + escapeXml(datasetUri) + "\"/>\n");
        pmml.append("<AlgorithmParameters />\n");
        pmml.append("<FeatureDefinitions>\n");
        for (int k = 0; k < numAttributes; k++) {
            pmml.append("<link href=\"" + escapeXml(data.attribute(k).name()) + "\"/>\n");
        }
        pmml.append("<target index=\"" + classIndex + "\" name=\""
                + targetName + "\"/>\n");
        pmml.append("</FeatureDefinitions>\n");
        pmml.append("<Timestamp>" + java.util.GregorianCalendar.getInstance().getTime() + "</Timestamp>\n");
        pmml.append("</Model>\n");

        // DataDictionary
        pmml.append("<DataDictionary numberOfFields=\"" + numAttributes + "\" >\n");
        for (int k = 0; k < numAttributes; k++) {
            pmml.append("<DataField name=\"" + escapeXml(data.attribute(k).name())
                    + "\" optype=\"continuous\" dataType=\"double\" />\n");
        }
        pmml.append("</DataDictionary>\n");

        // RegressionModel
        pmml.append("<RegressionModel modelName=\"" + uuid.toString() + "\""
                + " functionName=\"regression\""
                + " modelType=\"linearRegression\""
                + " algorithmName=\"linearRegression\""
                + " targetFieldName=\"" + targetName + "\""
                + ">\n");

        // RegressionModel::MiningSchema
        pmml.append("<MiningSchema>\n");
        for (int k = 0; k < numAttributes; k++) {
            if (k != classIndex) {
                pmml.append("<MiningField name=\""
                        + escapeXml(data.attribute(k).name()) + "\" />\n");
            }
        }
        pmml.append("<MiningField name=\""
                + targetName + "\" "
                + "usageType=\"predicted\"/>\n");
        pmml.append("</MiningSchema>\n");

        // RegressionModel::RegressionTable
        pmml.append("<RegressionTable intercept=\"" + coefficients[coefficients.length - 1] + "\">\n");
        for (int k = 0; k < numAttributes; k++) {
            // Skip the class attribute by index, exactly as the MiningSchema above does.
            if (k != classIndex) {
                pmml.append("<NumericPredictor name=\""
                        + escapeXml(data.attribute(k).name()) + "\" "
                        + " exponent=\"1\" "
                        + "coefficient=\"" + coefficients[k] + "\"/>\n");
            }
        }
        pmml.append("</RegressionTable>\n");
        pmml.append("</RegressionModel>\n");
        pmml.append("</PMML>\n\n");

        FileOutputStream out = null;
        BufferedWriter writer = null;
        try {
            out = new FileOutputStream(ServerFolders.models_pmml + "/" + uuid.toString());
            // The XML declaration names no encoding, so parsers will assume UTF-8.
            writer = new BufferedWriter(new java.io.OutputStreamWriter(out, "UTF-8"));
            writer.write(pmml.toString());
            writer.flush();
            // Close inside the try block so that a failure while closing is reported.
            writer.close();
        } catch (IOException ex) {
            throw new YaqpIOException(Cause.XQReg3, "Could not write data to PMML file :" + uuid.toString(), ex);
        } finally {
            // Only relevant on the failure path; closing twice has no effect.
            closeQuietly(writer);
            closeQuietly(out);
        }
    }
    // </editor-fold>

    /**
     * Escapes the characters that are not allowed to appear literally in XML attribute
     * values or text (<code>&amp;</code>, <code>&lt;</code>, <code>&gt;</code>, double and
     * single quotes).
     *
     * @param value The value to escape; it is converted with {@link String#valueOf(Object)},
     *      hence <code>null</code> is rendered as the text <code>null</code>.
     * @return The escaped text.
     */
    private static String escapeXml(final Object value) {
        final String text = String.valueOf(value);
        final StringBuilder escaped = new StringBuilder(text.length() + 16);
        for (int i = 0; i < text.length(); i++) {
            final char c = text.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                case '"':
                    escaped.append("&quot;");
                    break;
                case '\'':
                    escaped.append("&apos;");
                    break;
                default:
                    escaped.append(c);
                    break;
            }
        }
        return escaped.toString();
    }

    /**
     * Closes the given resource, ignoring <code>null</code> and any failure while closing.
     * Intended for cleanup on paths where a more relevant exception is already on its way.
     *
     * @param resource The resource to close; may be <code>null</code>.
     */
    private static void closeQuietly(final java.io.Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (final IOException ignored) {
            // Best-effort cleanup: nothing sensible can be done about a failed close here.
        }
    }

    /**
     * Manual smoke test: builds a processing pipeline (input, dataset building, conversion
     * to instances, two filters and an MLR trainer), runs it on the hard-coded dataset URI
     * <code>http://localhost/6</code> and prints the code and id of the resulting model
     * together with the status of the pipeline. Requires the database connector to be
     * initializable.
     *
     * @param args Ignored.
     * @throws QSARException In case the trainer cannot be constructed.
     * @throws YaqpException In case the pipeline fails.
     * @throws URISyntaxException Declared for the construction of the dataset URI.
     */
    public static void main(String args[]) throws QSARException, YaqpException, URISyntaxException {
        TheDbConnector.init();

        final InputProcessor p1 = new InputProcessor();
        final DatasetBuilder p2 = new DatasetBuilder();
        final InstancesProcessor p3 = new InstancesProcessor();
        final InstancesFilter p4 = new SimpleMVHFilter();
        final InstancesFilter p5 = new AttributeCleanup(new ATTRIBUTE_TYPE[]{ATTRIBUTE_TYPE.string, ATTRIBUTE_TYPE.nominal});

        final Map<String, AlgorithmParameter> params = new HashMap<String, AlgorithmParameter>();
        params.put("prediction_feature", new AlgorithmParameter<String>(ServerList.ambit + "/feature/11954"));
        params.put("dataset_uri", new AlgorithmParameter<String>("http://localhost/6"));
        final WekaRegressor p6 = new MLRTrainer(params);

        final Pipeline pipe = new Pipeline();
        pipe.add(p1);
        pipe.add(p2);
        pipe.add(p3);
        pipe.add(p4);
        pipe.add(p5);
        pipe.add(p6);

        URI u = new URI("http://localhost/6");
        pipe.setfailSensitive(true);
        final QSARModel model = (QSARModel) pipe.process(u);

        System.out.println(model.getCode());
        System.out.println(model.getId());
        System.out.println(pipe.getStatus());
    }
}