/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.svm.training;
import org.encog.Encog;
import org.encog.EncogError;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.Format;
import org.encog.util.logging.EncogLogging;
/**
 * Provides training for Support Vector Machine networks.
 *
 * A single call to {@link #iteration()} either trains the model or, when the
 * fold count is greater than 1, cross validates it. The C and gamma values
 * held by this trainer are copied into the SVM's parameters at the start of
 * every iteration.
 */
public class SVMTrain extends BasicTraining {

    /**
     * The network that is to be trained.
     */
    private final SVM network;

    /**
     * The problem to train for. Encoded once from the training data when the
     * trainer is constructed.
     */
    private final svm_problem problem;

    /**
     * The number of folds. Cross validation is only performed when this is
     * greater than 1.
     */
    private int fold = 0;

    /**
     * Is the training done.
     */
    private boolean trainingDone;

    /**
     * The gamma value.
     */
    private double gamma;

    /**
     * The const c value.
     */
    private double c;

    /**
     * Construct a trainer for an SVM network. Gamma defaults to one over the
     * network's input count and C defaults to 1.0.
     *
     * @param method
     *            The network to train.
     * @param dataSet
     *            The training data for this network.
     */
    public SVMTrain(final SVM method, final MLDataSet dataSet) {
        super(TrainingImplementationType.OnePass);
        this.network = method;
        setTraining(dataSet);
        this.trainingDone = false;
        this.problem = EncodeSVMProblem.encode(dataSet, 0);
        this.gamma = 1.0 / this.network.getInputCount();
        this.c = 1.0;
    }

    /**
     * Reject a value that is NaN, not positive, or smaller than
     * {@link Encog#DEFAULT_DOUBLE_EQUAL}.
     *
     * @param value
     *            The value to check.
     * @param message
     *            The message for the error thrown when the value is rejected.
     * @throws EncogError
     *             If the value is rejected.
     */
    private static void requireUsable(final double value, final String message) {
        // NaN fails every ordered comparison, so it needs its own test.
        if (Double.isNaN(value) || value <= 0
                || value < Encog.DEFAULT_DOUBLE_EQUAL) {
            throw new EncogError(message);
        }
    }

    /**
     * {@inheritDoc}
     *
     * This trainer always reports that it cannot continue.
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Evaluate the result of a cross validation run.
     *
     * For the regression types (EPSILON_SVR and NU_SVR) the value returned is
     * the error computed by {@link ErrorCalculation} over every
     * actual/ideal pair. For every other SVM type the value returned is
     * {@code Format.HUNDRED_PERCENT} scaled by the fraction of outputs that
     * exactly equal the ideal value, i.e. a measure of correct predictions
     * rather than of incorrect ones.
     *
     * @param param
     *            The params for the SVM.
     * @param prob
     *            The problem to evaluate.
     * @param target
     *            The output values from the SVM.
     * @return The calculated value, as described above.
     */
    private double evaluate(final svm_parameter param, final svm_problem prob,
            final double[] target) {
        final boolean regression = (param.svm_type == svm_parameter.EPSILON_SVR)
                || (param.svm_type == svm_parameter.NU_SVR);

        if (regression) {
            final ErrorCalculation error = new ErrorCalculation();
            for (int i = 0; i < prob.l; i++) {
                error.updateError(target[i], prob.y[i]);
            }
            return error.calculate();
        }

        int totalCorrect = 0;
        for (int i = 0; i < prob.l; i++) {
            if (target[i] == prob.y[i]) {
                ++totalCorrect;
            }
        }
        return Format.HUNDRED_PERCENT * totalCorrect / prob.l;
    }

    /**
     * @return The constant C.
     */
    public double getC() {
        return this.c;
    }

    /**
     * @return the fold
     */
    public int getFold() {
        return this.fold;
    }

    /**
     * @return The gamma.
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return The problem being trained.
     */
    public svm_problem getProblem() {
        return this.problem;
    }

    /**
     * @return True if the training is done.
     */
    @Override
    public boolean isTrainingDone() {
        return this.trainingDone;
    }

    /**
     * Perform either a train or a cross validation. If the folds property is
     * greater than 1 then cross validation will be done. Cross validation does
     * not produce a usable model (the network's model is set to null), but it
     * does set the error.
     *
     * If you are cross validating try C and Gamma values until you have a good
     * error rate. Then use those values to train, producing the final model.
     */
    @Override
    public void iteration() {
        // Assumes getParams() hands back the network's own parameter object,
        // as the original repeated calls to it also required.
        final svm_parameter params = this.network.getParams();
        params.C = this.c;
        params.gamma = this.gamma;

        EncogLogging.log(EncogLogging.LEVEL_INFO, "Training with parameters C = " + c + ", gamma = " + gamma);

        if (this.fold > 1) {
            // cross validate
            final double[] target = new double[this.problem.l];
            svm.svm_cross_validation(this.problem, params, this.fold, target);
            this.network.setModel(null);
            setError(evaluate(params, this.problem, target));
        } else {
            // train
            this.network.setModel(svm.svm_train(this.problem, params));
            setError(this.network.calculateError(getTraining()));
        }

        this.trainingDone = true;
    }

    /**
     * {@inheritDoc}
     *
     * Pausing is not supported; this always returns null.
     */
    @Override
    public final TrainingContinuation pause() {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * Resuming is not supported; the state is ignored.
     */
    @Override
    public void resume(final TrainingContinuation state) {
    }

    /**
     * Set the constant C. The current value is left untouched when the new
     * value is rejected.
     *
     * @param theC
     *            The constant C.
     * @throws EncogError
     *             If the value is NaN, not positive, or smaller than
     *             {@link Encog#DEFAULT_DOUBLE_EQUAL}.
     */
    public void setC(final double theC) {
        // Validate before storing so a rejected value never reaches training.
        requireUsable(theC, "SVM training cannot use a c value less than zero.");
        this.c = theC;
    }

    /**
     * Set the number of folds.
     *
     * @param theFold
     *            the fold to set.
     */
    public void setFold(final int theFold) {
        this.fold = theFold;
    }

    /**
     * Set the gamma. The current value is left untouched when the new value
     * is rejected.
     *
     * @param theGamma
     *            The new gamma.
     * @throws EncogError
     *             If the value is NaN, not positive, or smaller than
     *             {@link Encog#DEFAULT_DOUBLE_EQUAL}.
     */
    public void setGamma(final double theGamma) {
        // Validate before storing so a rejected value never reaches training.
        requireUsable(theGamma, "SVM training cannot use a gamma value less than zero.");
        this.gamma = theGamma;
    }
}