/*
*
* YAQP - Yet Another QSAR Project:
* Machine Learning algorithms designed for the prediction of toxicological
* features of chemical compounds become available on the Web. Yaqp is developed
* under OpenTox (http://opentox.org) which is an FP7-funded EU research project.
* This project was developed at the Automatic Control Lab in the Chemical Engineering
* School of the National Technical University of Athens. Please read README for more
* information.
*
* Copyright (C) 2009-2010 Pantelis Sopasakis & Charalampos Chomenides
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Contact:
* Pantelis Sopasakis
* chvng@mail.ntua.gr
* Address: Iroon Politechniou St. 9, Zografou, Athens Greece
* tel. +30 210 7723236
*/
package org.opentox.qsar.processors.trainers.classification;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import org.opentox.config.ServerFolders;
import org.opentox.core.exceptions.Cause;
import org.opentox.ontology.components.Feature;
import org.opentox.ontology.components.QSARModel;
import org.opentox.ontology.components.QSARModel.ModelStatus;
import org.opentox.ontology.util.AlgorithmParameter;
import org.opentox.ontology.util.YaqpAlgorithms;
import org.opentox.ontology.util.vocabulary.ConstantParameters;
import org.opentox.qsar.exceptions.QSARException;
import org.opentox.www.rest.components.YaqpForm;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.functions.supportVector.RBFKernel;
import weka.core.Instances;
import weka.core.converters.ArffSaver;
/**
 * Trainer for support vector classification (SVC) models. Training is carried
 * out by Weka's {@link SMO} classifier using an RBF, polynomial or linear
 * kernel. Every tuning parameter is initialised from
 * {@link ConstantParameters#SVMParams()} and may be overridden by the values
 * a client posts in a {@link YaqpForm}.
 *
 * @author Pantelis Sopasakis
 * @author Charalampos Chomenides
 */
public final class SVCTrainer extends WekaClassifier {

    /**
     * The parameter gamma (handed to the RBF kernel).
     */
    private double gamma = Double.parseDouble(
            ConstantParameters.SVMParams().get(ConstantParameters.gamma).paramValue.toString());
    /**
     * The cost used in the trainer's cost function (the complexity constant C of SMO).
     */
    private double cost = Double.parseDouble(
            ConstantParameters.SVMParams().get(ConstantParameters.cost).paramValue.toString());
    /**
     * The bias of the kernel function of the SVM model. It is recorded among the
     * model parameters but is not passed to any of the kernels built by this class.
     */
    private double coeff0 = Double.parseDouble(
            ConstantParameters.SVMParams().get(ConstantParameters.coeff0).paramValue.toString());
    /**
     * Maximum cache size of the kernel.
     */
    private int cacheSize = Integer.parseInt(
            ConstantParameters.SVMParams().get(ConstantParameters.cacheSize).paramValue.toString());
    /**
     * Degree of a polynomial kernel.
     */
    private int degree = Integer.parseInt(
            ConstantParameters.SVMParams().get(ConstantParameters.degree).paramValue.toString());
    /**
     * Convergence criterion.
     */
    private double tolerance = Double.parseDouble(
            ConstantParameters.SVMParams().get(ConstantParameters.tolerance).paramValue.toString());
    /**
     * The kernel of the SVM model (RBF, LINEAR or POLYNOMIAL; compared ignoring case).
     */
    private String kernel = ConstantParameters.SVMParams().get(ConstantParameters.kernel).paramValue.toString();

    /**
     * Construct a new SVM Trainer that uses the default value of every parameter.
     */
    public SVCTrainer() {
        super();
    }

    /**
     * Construct a new SVM trainer given the set of parameters posted by the client as
     * an instance of {@link YaqpForm }. If an optional parameter such as <code>gamma</code>
     * is not provided, its default value is assigned to it. Every parameter (posted or
     * default) is registered through <code>putParameter</code>.
     *
     * @param form
     *      Parameters provided by the client
     * @throws QSARException
     *      In case the training is infeasible due to unacceptable parameters provided
     *      by the client: a value that is not numeric, a non-positive (or NaN) gamma or
     *      cost, a tolerance below 1E-6 (or NaN) or an unknown kernel name.
     * @throws NullPointerException
     *      If the provided form is null.
     */
    @SuppressWarnings({"unchecked"})
    public SVCTrainer(final YaqpForm form) throws QSARException {
        super(form);
        /*
         * Most of the following code is copied from SVMTrainer
         */

        // CHECK GAMMA
        final String postedGamma = form.getFirstValue(ConstantParameters.gamma);
        try {
            if (postedGamma != null) {
                this.gamma = Double.parseDouble(postedGamma);
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3001, "Parameter gamma should be numeric. "
                    + "You provided the illegal value : {" + postedGamma + "}", ex);
        }
        // "NaN" is accepted by Double.parseDouble and fails every comparison,
        // so it has to be rejected explicitly.
        if (gamma <= 0 || Double.isNaN(gamma)) {
            throw new QSARException(
                    Cause.XQReg3002, "The parameter gamma must be strictly positive. "
                    + "You provided the illegal value: {" + gamma + "}");
        }
        putParameter(ConstantParameters.gamma, new AlgorithmParameter(gamma));

        // CHECK COST
        final String postedCost = form.getFirstValue(ConstantParameters.cost);
        try {
            if (postedCost != null) {
                this.cost = Double.parseDouble(postedCost);
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3003, "Parameter " + ConstantParameters.cost + " should be numeric. "
                    + "You provided the illegal "
                    + "value : {" + postedCost + "}", ex);
        }
        if (cost <= 0 || Double.isNaN(cost)) {
            throw new QSARException(
                    Cause.XQReg3004, "The parameter " + ConstantParameters.cost + " must be strictly positive. "
                    + "You provided the illegal value: {" + cost + "}");
        }
        putParameter(ConstantParameters.cost, new AlgorithmParameter(cost));

        // CHECK COEFF_0
        final String postedCoeff0 = form.getFirstValue(ConstantParameters.coeff0);
        try {
            if (postedCoeff0 != null) {
                this.coeff0 = Double.parseDouble(postedCoeff0);
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3007, "Parameter " + ConstantParameters.coeff0 + " should be numeric. "
                    + "You provided the illegal "
                    + "value : {" + postedCoeff0 + "}", ex);
        }
        putParameter(ConstantParameters.coeff0, new AlgorithmParameter(coeff0));

        // CHECK CACHE SIZE
        final String postedCacheSize = form.getFirstValue(ConstantParameters.cacheSize);
        try {
            if (postedCacheSize != null) {
                this.cacheSize = Integer.parseInt(postedCacheSize);
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3008, "Parameter " + ConstantParameters.cacheSize + " should be integer. "
                    + "You provided the illegal "
                    + "value : {" + postedCacheSize + "}", ex);
        }
        putParameter(ConstantParameters.cacheSize, new AlgorithmParameter(cacheSize));

        // CHECK DEGREE
        final String postedDegree = form.getFirstValue(ConstantParameters.degree);
        try {
            if (postedDegree != null) {
                this.degree = Integer.parseInt(postedDegree);
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3009, "Parameter " + ConstantParameters.degree + " should be integer. "
                    + "You provided the illegal "
                    + "value : {" + postedDegree + "}", ex);
        }
        putParameter(ConstantParameters.degree, new AlgorithmParameter(degree));

        // CHECK TOLERANCE
        final String postedTolerance = form.getFirstValue(ConstantParameters.tolerance);
        try {
            if (postedTolerance != null) {
                this.tolerance = Double.parseDouble(postedTolerance);
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3010, "Parameter " + ConstantParameters.tolerance + " should be numeric. "
                    + "You provided the illegal value : {" + postedTolerance + "}", ex);
        }
        if (tolerance < 1E-6 || Double.isNaN(tolerance)) {
            throw new QSARException(
                    Cause.XQReg3011, "The parameter " + ConstantParameters.tolerance + " must be greater that 1E-6. "
                    + "You provided the illegal value: {"
                    + tolerance + "}");
        }
        putParameter(ConstantParameters.tolerance, new AlgorithmParameter(tolerance));

        // CHECK KERNEL
        final String postedKernel = form.getFirstValue(ConstantParameters.kernel);
        if (postedKernel != null) {
            // Upper-case with a fixed locale: under e.g. a Turkish default locale
            // "linear" would become "LİNEAR" and be rejected although it is valid.
            this.kernel = postedKernel.toUpperCase(java.util.Locale.ENGLISH);
            if (!kernel.equals("RBF") && !kernel.equals("LINEAR") && !kernel.equals("POLYNOMIAL")) {
                throw new QSARException(Cause.XQReg3012, "The available kernels are [RBF; LINEAR; POLYNOMIAL]. Note that "
                        + "this parameter is not case-sensitive, i.e. rbf is the same as RbF. However you provided "
                        + "the illegal value : {" + kernel + "}");
            }
        }
        putParameter(ConstantParameters.kernel, new AlgorithmParameter(kernel));
    }/* End of constructor */

    /**
     * Train a support vector classification model on the given data. The data
     * are first stored in a temporary ARFF file under {@link ServerFolders#temp};
     * the classifier is then built by {@link Evaluation#evaluateModel} which is
     * asked (option <code>-d</code>) to save the model in a file named after the
     * UUID of this trainer under {@link ServerFolders#models_weka}. The temporary
     * file is removed before this method returns, whether training succeeds or not.
     *
     * @param data
     *      The training data; its class index designates the dependent feature.
     * @return
     *      A {@link QSARModel} describing the trained model, with status
     *      {@link ModelStatus#UNDER_DEVELOPMENT}. All attributes of the data other
     *      than the class attribute are its independent features.
     * @throws QSARException
     *      If the configured kernel is not supported or Weka fails to train the model.
     * @throws RuntimeException
     *      If the data cannot be written to the temporary ARFF file.
     */
    public QSARModel train(Instances data) throws QSARException {

        // GET A UUID AND DEFINE THE TEMPORARY FILE WHERE THE TRAINING DATA
        // ARE STORED IN ARFF FORMAT PRIOR TO TRAINING.
        final String rand = java.util.UUID.randomUUID().toString();
        final String temporaryFilePath = ServerFolders.temp + "/" + rand + ".arff";
        final File tempFile = new File(temporaryFilePath);

        // SAVE THE DATA IN THE TEMPORARY FILE
        saveAsArff(data, tempFile);

        try {
            // INITIALIZE THE CLASSIFIER
            SMO classifier = new SMO();
            // NOTE(review): 0.1 is hard-coded here; confirm against the Weka SMO
            // documentation that this is the intended value for setEpsilon.
            classifier.setEpsilon(0.1);
            classifier.setToleranceParameter(tolerance);
            // Apply the validated cost; without this call the posted value
            // would be stored in the model parameters but never used.
            classifier.setC(cost);
            classifier.setKernel(buildKernel());

            String modelFilePath = ServerFolders.models_weka + "/" + uuid.toString();
            String[] generalOptions = {
                // Weka counts class attributes starting from 1
                "-c", Integer.toString(data.classIndex() + 1),
                "-t", temporaryFilePath,
                /// Save the model in the following directory
                "-d", modelFilePath};

            // AFTER ALL, BUILD THE CLASSIFICATION MODEL AND SAVE IT AS A SERIALIZED
            // WEKA FILE IN THE CORRESPONDING DIRECTORY.
            try {
                Evaluation.evaluateModel(classifier, generalOptions);
            } catch (final Exception ex) {
                throw new QSARException(Cause.XQReg350, "Unexpected condition while trying to train "
                        + "a support vector classification model. Possible explanation : {" + ex.getMessage() + "}", ex);
            }

            // Every attribute except the class attribute is an independent feature.
            ArrayList<Feature> independentFeatures = new ArrayList<Feature>();
            for (int i = 0; i < data.numAttributes(); i++) {
                if (data.classIndex() != i) {
                    independentFeatures.add(new Feature(data.attribute(i).name()));
                }
            }

            Feature dependentFeature = new Feature(data.classAttribute().name());
            Feature predictedFeature = dependentFeature;

            QSARModel model = new QSARModel();
            model.setCode(uuid.toString());
            model.setAlgorithm(YaqpAlgorithms.SVC);
            model.setPredictionFeature(predictedFeature);
            model.setDependentFeature(dependentFeature);
            model.setIndependentFeatures(independentFeatures);
            model.setDataset(datasetUri);
            model.setParams(getParameters());
            model.setModelStatus(ModelStatus.UNDER_DEVELOPMENT);
            return model;
        } finally {
            // The ARFF copy is needed only while training; remove it on every exit path.
            tempFile.delete();
        }
    }

    /**
     * Write the given data to the destination file in ARFF format. The output
     * stream is always closed; if writing fails the (possibly partial) file is
     * deleted.
     *
     * @param data
     *      The instances to be saved.
     * @param destination
     *      The file to be written.
     * @throws RuntimeException
     *      If the data cannot be written to the destination file.
     */
    private void saveAsArff(final Instances data, final File destination) {
        try {
            FileOutputStream out = null;
            try {
                ArffSaver dataSaver = new ArffSaver();
                dataSaver.setInstances(data);
                out = new FileOutputStream(destination);
                dataSaver.setDestination(out);
                dataSaver.writeBatch();
                if (!destination.exists()) {
                    throw new IOException("Temporary File was not created");
                }
            } finally {
                // Close before any deletion attempt; closing an already closed
                // FileOutputStream has no effect.
                if (out != null) {
                    try {
                        out.close();
                    } catch (final IOException ignored) {
                        // Nothing useful can be done if closing fails.
                    }
                }
            }
        } catch (final IOException ex) {/*
             * The content of the dataset cannot be
             * written to the destination file due to
             * some communication issue.
             */
            destination.delete();
            throw new RuntimeException("Unexpected condition while trying to save the "
                    + "dataset in a temporary ARFF file", ex);
        }
    }

    /**
     * Construct a kernel according to the parameters of this trainer.
     * Supported kernels are {rbf, linear, polynomial}; the linear kernel is a
     * polynomial kernel of exponent 1.
     *
     * @return
     *      The configured kernel, never null.
     * @throws QSARException
     *      If the kernel name is none of the supported ones.
     */
    private Kernel buildKernel() throws QSARException {
        if (this.kernel.equalsIgnoreCase("rbf")) {
            RBFKernel rbfKernel = new RBFKernel();
            rbfKernel.setGamma(gamma);
            rbfKernel.setCacheSize(cacheSize);
            return rbfKernel;
        }
        if (this.kernel.equalsIgnoreCase("polynomial")) {
            PolyKernel polyKernel = new PolyKernel();
            polyKernel.setExponent(degree);
            polyKernel.setCacheSize(cacheSize);
            polyKernel.setUseLowerOrder(true);
            return polyKernel;
        }
        if (this.kernel.equalsIgnoreCase("linear")) {
            PolyKernel linearKernel = new PolyKernel();
            linearKernel.setExponent(1.0);
            linearKernel.setCacheSize(cacheSize);
            linearKernel.setUseLowerOrder(true);
            return linearKernel;
        }
        // Fail with an explicit message instead of handing a null kernel to Weka.
        throw new QSARException(Cause.XQReg350, "Unexpected condition while trying to train "
                + "a support vector classification model. Possible explanation : {Unsupported kernel: "
                + kernel + "}");
    }
}