/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.svm;
import org.encog.Encog;
import org.encog.engine.util.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.svm.KernelType;
import org.encog.neural.networks.svm.SVMNetwork;
import org.encog.neural.networks.training.BasicTraining;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Provides training for Support Vector Machine networks.
 *
 * <p>Two training modes are supported. For radial basis function (RBF)
 * kernels, repeated calls to {@link #iteration()} perform a grid search over
 * the C and gamma parameters, cross-validating each candidate pair and
 * remembering the best values found; {@link #finishTraining()} then trains
 * the final models using those best values. For all other kernel types a
 * single call to {@link #train()} fits the models directly.</p>
 *
 * <p>This class is not thread-safe.</p>
 */
public class SVMTrain extends BasicTraining {

	/**
	 * The logger. (The {@code transient} modifier the original carried was
	 * meaningless on a static field and has been dropped.)
	 */
	private static final Logger LOGGER = LoggerFactory
			.getLogger(SVMTrain.class);

	/**
	 * The default starting number for C.
	 */
	public static final double DEFAULT_CONST_BEGIN = -5;

	/**
	 * The default ending number for C.
	 */
	public static final double DEFAULT_CONST_END = 15;

	/**
	 * The default step for C.
	 */
	public static final double DEFAULT_CONST_STEP = 2;

	/**
	 * The default gamma begin.
	 */
	public static final double DEFAULT_GAMMA_BEGIN = -10;

	/**
	 * The default gamma end.
	 */
	public static final double DEFAULT_GAMMA_END = 10;

	/**
	 * The default gamma step.
	 */
	public static final double DEFAULT_GAMMA_STEP = 1;

	/**
	 * The network that is to be trained.
	 */
	private SVMNetwork network;

	/**
	 * The problems to train for, one per network output.
	 */
	private svm_problem[] problem;

	/**
	 * The number of cross-validation folds.
	 */
	private int fold = 5;

	/**
	 * The beginning value for C.
	 */
	private double constBegin = DEFAULT_CONST_BEGIN;

	/**
	 * The step value for C.
	 * BUG FIX: this was initialized to DEFAULT_CONST_END (swapped with
	 * constEnd below), which made the step (15) larger than the end (2), so
	 * the grid search visited only a single C value per gamma.
	 */
	private double constStep = DEFAULT_CONST_STEP;

	/**
	 * The ending value for C.
	 * BUG FIX: was initialized to DEFAULT_CONST_STEP (swapped, see above).
	 */
	private double constEnd = DEFAULT_CONST_END;

	/**
	 * The beginning value for gamma.
	 */
	private double gammaBegin = DEFAULT_GAMMA_BEGIN;

	/**
	 * The ending value for gamma.
	 */
	private double gammaEnd = DEFAULT_GAMMA_END;

	/**
	 * The step value for gamma.
	 */
	private double gammaStep = DEFAULT_GAMMA_STEP;

	/**
	 * The best values found for C, one per output.
	 */
	private double[] bestConst;

	/**
	 * The best values found for gamma, one per output.
	 */
	private double[] bestGamma;

	/**
	 * The best (lowest) error found so far, one per output.
	 */
	private double[] bestError;

	/**
	 * The current C for each output's grid search.
	 */
	private double[] currentConst;

	/**
	 * The current gamma for each output's grid search.
	 */
	private double[] currentGamma;

	/**
	 * Have the grid-search arrays been allocated yet?
	 */
	private boolean isSetup;

	/**
	 * Is the training done.
	 */
	private boolean trainingDone;

	/**
	 * Construct a trainer for an SVM network.
	 *
	 * @param network
	 *            The network to train; must actually be an {@link SVMNetwork}.
	 * @param training
	 *            The training data for this network.
	 */
	public SVMTrain(BasicNetwork network, NeuralDataSet training) {
		this.network = (SVMNetwork) network;
		this.setTraining(training);
		this.isSetup = false;
		this.trainingDone = false;

		// Encode one libsvm problem per network output.
		this.problem = new svm_problem[this.network.getOutputCount()];
		for (int i = 0; i < this.network.getOutputCount(); i++) {
			this.problem[i] = EncodeSVMProblem.encode(training, i);
		}
	}

	/**
	 * Quickly train all outputs with a C of 1.0 and a gamma equal to 1/(num
	 * inputs).
	 */
	public void train() {
		double gamma = 1.0 / this.network.getInputCount();
		double c = 1.0;

		for (int i = 0; i < network.getOutputCount(); i++) {
			train(i, gamma, c);
		}
	}

	/**
	 * Quickly train one output with the specified gamma and C.
	 *
	 * @param index
	 *            The output to train.
	 * @param gamma
	 *            The gamma to train with. A value that is effectively zero
	 *            selects the default of 1/(num inputs).
	 * @param c
	 *            The C to train with.
	 */
	public void train(int index, double gamma, double c) {
		network.getParams()[index].C = c;

		// BUG FIX: the original condition was inverted
		// (gamma > Encog.DEFAULT_DOUBLE_EQUAL chose the default), which
		// silently discarded every explicitly supplied gamma - including the
		// best gamma found by the grid search. Fall back to the default only
		// when gamma is effectively zero.
		if (Math.abs(gamma) < Encog.DEFAULT_DOUBLE_EQUAL) {
			network.getParams()[index].gamma = 1.0
					/ this.network.getInputCount();
		} else {
			network.getParams()[index].gamma = gamma;
		}

		network.getModels()[index] = svm.svm_train(problem[index],
				network.getParams()[index]);
	}

	/**
	 * Cross validate and check the specified index/gamma.
	 *
	 * @param index
	 *            The output index to cross validate.
	 * @param gamma
	 *            The gamma to check.
	 * @param c
	 *            The C to check.
	 * @return The calculated error; lower is better.
	 */
	public double crossValidate(int index, double gamma, double c) {
		// BUG FIX: size the target array from the problem actually being
		// validated, not problem[0].
		double[] target = new double[this.problem[index].l];

		network.getParams()[index].C = c;
		network.getParams()[index].gamma = gamma;
		svm.svm_cross_validation(problem[index], network.getParams()[index],
				fold, target);
		return evaluate(network.getParams()[index], problem[index], target);
	}

	/**
	 * Evaluate the error for the specified model.
	 *
	 * @param param
	 *            The params for the SVM.
	 * @param prob
	 *            The problem to evaluate.
	 * @param target
	 *            The output values from the SVM.
	 * @return The calculated error; lower is better. For regression this is
	 *         the error from {@link ErrorCalculation}; for classification it
	 *         is the misclassification percentage (0-100).
	 */
	private double evaluate(svm_parameter param, svm_problem prob,
			double[] target) {
		int totalCorrect = 0;
		ErrorCalculation error = new ErrorCalculation();

		if (param.svm_type == svm_parameter.EPSILON_SVR
				|| param.svm_type == svm_parameter.NU_SVR) {
			// Regression: accumulate the standard error measure.
			for (int i = 0; i < prob.l; i++) {
				double ideal = prob.y[i];
				double actual = target[i];
				error.updateError(actual, ideal);
			}
			return error.calculate();
		} else {
			// Classification: count exact matches.
			for (int i = 0; i < prob.l; i++) {
				if (target[i] == prob.y[i]) {
					++totalCorrect;
				}
			}
			// BUG FIX: return the misclassification percentage rather than
			// the accuracy percentage. Callers (the grid search in
			// iteration()) MINIMIZE this value, so returning accuracy caused
			// the search to select the worst parameters.
			return 100.0 * (prob.l - totalCorrect) / prob.l;
		}
	}

	/**
	 * Allocate and initialize the per-output grid-search state.
	 */
	private void setup() {
		this.currentConst = new double[this.network.getOutputCount()];
		this.currentGamma = new double[this.network.getOutputCount()];
		this.bestConst = new double[this.network.getOutputCount()];
		this.bestGamma = new double[this.network.getOutputCount()];
		this.bestError = new double[this.network.getOutputCount()];

		for (int i = 0; i < this.network.getOutputCount(); i++) {
			this.currentConst[i] = this.constBegin;
			this.currentGamma[i] = this.gammaBegin;
			// Any real error will beat positive infinity on the first pass.
			this.bestError[i] = Double.POSITIVE_INFINITY;
		}
		this.isSetup = true;
	}

	/**
	 * Perform one training iteration. For RBF kernels this advances the grid
	 * search over C and gamma by one step; for all other kernels it simply
	 * trains the models directly.
	 */
	public void iteration() {

		if (!trainingDone) {
			if (!isSetup) {
				setup();
			}

			preIteration();

			if (network.getKernelType() == KernelType.RadialBasisFunction) {

				double totalError = 0;

				for (int i = 0; i < this.network.getOutputCount(); i++) {
					double e = this.crossValidate(i, this.currentGamma[i],
							currentConst[i]);

					// Remember the best (lowest-error) pair seen so far.
					if (e < bestError[i]) {
						this.bestConst[i] = this.currentConst[i];
						this.bestGamma[i] = this.currentGamma[i];
						this.bestError[i] = e;
					}

					// Advance C; when C wraps, advance gamma. The search is
					// done once gamma passes its end value.
					this.currentConst[i] += this.constStep;
					if (this.currentConst[i] > this.constEnd) {
						this.currentConst[i] = this.constBegin;
						this.currentGamma[i] += this.gammaStep;
						if (this.currentGamma[i] > this.gammaEnd) {
							this.trainingDone = true;
						}
					}

					totalError += this.bestError[i];
				}

				setError(totalError / this.network.getOutputCount());
			} else {
				train();
			}

			postIteration();
		}
	}

	/**
	 * @return The problem being trained.
	 */
	public svm_problem[] getProblem() {
		return problem;
	}

	/**
	 * @return the number of cross-validation folds
	 */
	public int getFold() {
		return fold;
	}

	/**
	 * @param fold
	 *            the number of cross-validation folds to use
	 */
	public void setFold(int fold) {
		this.fold = fold;
	}

	/**
	 * @return the constBegin
	 */
	public double getConstBegin() {
		return constBegin;
	}

	/**
	 * @param constBegin
	 *            the constBegin to set
	 */
	public void setConstBegin(double constBegin) {
		this.constBegin = constBegin;
	}

	/**
	 * @return the constStep
	 */
	public double getConstStep() {
		return constStep;
	}

	/**
	 * @param constStep
	 *            the constStep to set
	 */
	public void setConstStep(double constStep) {
		this.constStep = constStep;
	}

	/**
	 * @return the constEnd
	 */
	public double getConstEnd() {
		return constEnd;
	}

	/**
	 * @param constEnd
	 *            the constEnd to set
	 */
	public void setConstEnd(double constEnd) {
		this.constEnd = constEnd;
	}

	/**
	 * @return the gammaBegin
	 */
	public double getGammaBegin() {
		return gammaBegin;
	}

	/**
	 * @param gammaBegin
	 *            the gammaBegin to set
	 */
	public void setGammaBegin(double gammaBegin) {
		this.gammaBegin = gammaBegin;
	}

	/**
	 * @return the gammaEnd
	 */
	public double getGammaEnd() {
		return gammaEnd;
	}

	/**
	 * @param gammaEnd
	 *            the gammaEnd to set
	 */
	public void setGammaEnd(double gammaEnd) {
		this.gammaEnd = gammaEnd;
	}

	/**
	 * @return the gammaStep
	 */
	public double getGammaStep() {
		return gammaStep;
	}

	/**
	 * @param gammaStep
	 *            the gammaStep to set
	 */
	public void setGammaStep(double gammaStep) {
		this.gammaStep = gammaStep;
	}

	/**
	 * Called to finish training: retrains each output with the best C/gamma
	 * pair found by the grid search.
	 *
	 * @throws IllegalStateException
	 *             if no grid-search iteration has been performed yet.
	 */
	public void finishTraining() {
		// ROBUSTNESS FIX: previously this threw a bare NullPointerException
		// when called before iteration(); fail with a clear message instead.
		if (this.bestGamma == null || this.bestConst == null) {
			throw new IllegalStateException(
					"Cannot finish training: no iteration has been performed.");
		}
		for (int i = 0; i < network.getOutputCount(); i++) {
			train(i, this.bestGamma[i], this.bestConst[i]);
		}
	}

	/**
	 * @return The trained network.
	 */
	@Override
	public BasicNetwork getNetwork() {
		return this.network;
	}

	/**
	 * @return True if the training is done.
	 */
	public boolean isTrainingDone() {
		return this.trainingDone;
	}

	/**
	 * Quickly train the network with a fixed gamma and C.
	 *
	 * @param gamma
	 *            The gamma to use for every output.
	 * @param c
	 *            The C to use for every output.
	 */
	public void train(double gamma, double c) {
		for (int i = 0; i < this.network.getOutputCount(); i++) {
			train(i, gamma, c);
		}
	}
}