/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 *
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ml.svm.training;

import org.encog.Encog;
import org.encog.EncogError;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.KernelType;
import org.encog.ml.svm.SVM;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * Provides training for Support Vector Machine networks by searching a grid
 * of C and gamma values and keeping the combination that produces the lowest
 * error.
 */
public class SVMSearchTrain extends BasicTraining {

	/**
	 * The default starting number for C.
	 */
	public static final double DEFAULT_CONST_BEGIN = 1;

	/**
	 * The default ending number for C.
	 */
	public static final double DEFAULT_CONST_END = 15;

	/**
	 * The default step for C.
	 */
	public static final double DEFAULT_CONST_STEP = 2;

	/**
	 * The default gamma begin.
	 */
	public static final double DEFAULT_GAMMA_BEGIN = 1;

	/**
	 * The default gamma end.
	 */
	public static final double DEFAULT_GAMMA_END = 10;

	/**
	 * The default gamma step.
	 */
	public static final double DEFAULT_GAMMA_STEP = 1;

	/**
	 * The network that is to be trained.
	 */
	private final SVM network;

	/**
	 * The number of folds.
	 */
	private int fold = 0;

	/**
	 * The beginning value for C.
	 */
	private double constBegin = SVMSearchTrain.DEFAULT_CONST_BEGIN;

	/**
	 * The step value for C.
	 */
	private double constStep = SVMSearchTrain.DEFAULT_CONST_STEP;

	/**
	 * The ending value for C.
	 */
	private double constEnd = SVMSearchTrain.DEFAULT_CONST_END;

	/**
	 * The beginning value for gamma.
	 */
	private double gammaBegin = SVMSearchTrain.DEFAULT_GAMMA_BEGIN;

	/**
	 * The ending value for gamma.
	 */
	private double gammaEnd = SVMSearchTrain.DEFAULT_GAMMA_END;

	/**
	 * The step value for gamma.
	 */
	private double gammaStep = SVMSearchTrain.DEFAULT_GAMMA_STEP;

	/**
	 * The best value found for C.
	 */
	private double bestConst;

	/**
	 * The best value found for gamma.
	 */
	private double bestGamma;

	/**
	 * The best error.
	 */
	private double bestError;

	/**
	 * The current C.
	 */
	private double currentConst;

	/**
	 * The current gamma.
	 */
	private double currentGamma;

	/**
	 * Is the network setup.
	 */
	private boolean isSetup;

	/**
	 * Is the training done.
	 */
	private boolean trainingDone;

	/**
	 * The internal training object, used for the search.
	 */
	private final SVMTrain internalTrain;
	/**
	 * Construct a trainer for an SVM network.
	 *
	 * @param method
	 *            The method to train.
	 * @param training
	 *            The training data for this network.
	 */
	public SVMSearchTrain(final SVM method, final MLDataSet training) {
		super(TrainingImplementationType.Iterative);
		this.network = method;
		setTraining(training);
		this.isSetup = false;
		this.trainingDone = false;
		this.internalTrain = new SVMTrain(network, training);
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public boolean canContinue() {
		return false;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public void finishTraining() {
		// retrain one final time using the best C and gamma found by the search
		this.internalTrain.setGamma(this.bestGamma);
		this.internalTrain.setC(this.bestConst);
		this.internalTrain.iteration();
	}

	/**
	 * @return the constBegin
	 */
	public double getConstBegin() {
		return this.constBegin;
	}

	/**
	 * @return the constEnd
	 */
	public double getConstEnd() {
		return this.constEnd;
	}

	/**
	 * @return the constStep
	 */
	public double getConstStep() {
		return this.constStep;
	}

	/**
	 * @return the fold
	 */
	public int getFold() {
		return this.fold;
	}

	/**
	 * @return the gammaBegin
	 */
	public double getGammaBegin() {
		return this.gammaBegin;
	}

	/**
	 * @return the gammaEnd
	 */
	public double getGammaEnd() {
		return this.gammaEnd;
	}

	/**
	 * @return the gammaStep
	 */
	public double getGammaStep() {
		return this.gammaStep;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public MLMethod getMethod() {
		return this.network;
	}

	/**
	 * @return True if the training is done.
	 */
	@Override
	public boolean isTrainingDone() {
		return this.trainingDone;
	}

	/**
	 * Perform one training iteration.
	 */
	@Override
	public void iteration() {

		if (!this.trainingDone) {
			if (!this.isSetup) {
				setup();
			}

			preIteration();

			this.internalTrain.setFold(this.fold);

			if (this.network.getKernelType() == KernelType.RadialBasisFunction) {

				this.internalTrain.setGamma(this.currentGamma);
				this.internalTrain.setC(this.currentConst);
				this.internalTrain.iteration();
				final double e = this.internalTrain.getError();

				//System.out.println(this.currentGamma + "," + this.currentConst
				//		+ "," + e);

				// new best error?
				if (!Double.isNaN(e)) {
					if (e < this.bestError) {
						this.bestConst = this.currentConst;
						this.bestGamma = this.currentGamma;
						this.bestError = e;
					}
				}

				// advance to the next grid point: step C first, then gamma
				this.currentConst += this.constStep;
				if (this.currentConst > this.constEnd) {
					this.currentConst = this.constBegin;
					this.currentGamma += this.gammaStep;
					if (this.currentGamma > this.gammaEnd) {
						this.trainingDone = true;
					}
				}

				setError(this.bestError);
			} else {
				this.internalTrain.setGamma(this.currentGamma);
				this.internalTrain.setC(this.currentConst);
				this.internalTrain.iteration();
			}

			postIteration();
		}
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public TrainingContinuation pause() {
		return null;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public void resume(final TrainingContinuation state) {
	}

	/**
	 * @param theConstBegin
	 *            the constBegin to set
	 */
	public void setConstBegin(final double theConstBegin) {
		this.constBegin = theConstBegin;
	}

	/**
	 * @param theConstEnd
	 *            the constEnd to set
	 */
	public void setConstEnd(final double theConstEnd) {
		this.constEnd = theConstEnd;
	}

	/**
	 * @param theConstStep
	 *            the constStep to set
	 */
	public void setConstStep(final double theConstStep) {
		this.constStep = theConstStep;
	}

	/**
	 * @param theFold
	 *            the fold to set
	 */
	public void setFold(final int theFold) {
		this.fold = theFold;
	}

	/**
	 * @param theGammaBegin
	 *            the gammaBegin to set
	 */
	public void setGammaBegin(final double theGammaBegin) {
		this.gammaBegin = theGammaBegin;
	}
	/**
	 * @param theGammaEnd
	 *            the gammaEnd to set
	 */
	public final void setGammaEnd(final double theGammaEnd) {
		this.gammaEnd = theGammaEnd;
	}

	/**
	 * @param theGammaStep
	 *            the gammaStep to set
	 */
	public final void setGammaStep(final double theGammaStep) {
		this.gammaStep = theGammaStep;
	}

	/**
	 * Setup to train the SVM.
	 */
	private void setup() {

		this.currentConst = this.constBegin;
		this.currentGamma = this.gammaBegin;
		this.bestError = Double.POSITIVE_INFINITY;

		if (this.currentGamma <= 0 || this.currentGamma < Encog.DEFAULT_DOUBLE_EQUAL) {
			throw new EncogError(
					"SVM search training cannot use a gamma value less than or equal to zero.");
		}

		if (this.currentConst <= 0 || this.currentConst < Encog.DEFAULT_DOUBLE_EQUAL) {
			throw new EncogError(
					"SVM search training cannot use a const (C) value less than or equal to zero.");
		}

		if (this.gammaStep < 0) {
			throw new EncogError(
					"SVM search training cannot use a gamma step value less than zero.");
		}

		if (this.constStep < 0) {
			throw new EncogError(
					"SVM search training cannot use a const step value less than zero.");
		}

		this.isSetup = true;
	}

	/**
	 * @return the bestConst
	 */
	public double getBestConst() {
		return this.bestConst;
	}

	/**
	 * @param bestConst
	 *            the bestConst to set
	 */
	public void setBestConst(final double bestConst) {
		this.bestConst = bestConst;
	}

	/**
	 * @return the bestGamma
	 */
	public double getBestGamma() {
		return this.bestGamma;
	}

	/**
	 * @param bestGamma
	 *            the bestGamma to set
	 */
	public void setBestGamma(final double bestGamma) {
		this.bestGamma = bestGamma;
	}

}
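/*
 * Usage sketch (illustrative, with assumptions): the loop below shows one way
 * calling code might drive the C/gamma grid search. SVM, BasicMLDataSet and
 * SVMType are standard Encog classes, but the particular constructor
 * arguments (input count, SVM type, kernel type), the inputArray/idealArray
 * data, and the 5-fold setting are example choices, not requirements of this
 * class. Requires org.encog.ml.data.basic.BasicMLDataSet and
 * org.encog.ml.svm.SVMType in addition to the imports above.
 *
 *   MLDataSet training = new BasicMLDataSet(inputArray, idealArray);
 *   SVM svm = new SVM(training.getInputSize(),
 *       SVMType.EpsilonSupportVectorRegression, KernelType.RadialBasisFunction);
 *
 *   SVMSearchTrain search = new SVMSearchTrain(svm, training);
 *   search.setFold(5); // cross-validate each C/gamma pair with 5 folds
 *
 *   while (!search.isTrainingDone()) {
 *       search.iteration(); // evaluates one C/gamma grid point per call
 *   }
 *   search.finishTraining(); // retrains with the best C and gamma found
 *
 *   System.out.println("Best C=" + search.getBestConst()
 *       + ", best gamma=" + search.getBestGamma()
 *       + ", error=" + search.getError());
 */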