/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.train.strategy.end; /* * Encog(tm) Core v3.3 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2014 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ import org.encog.Encog; import org.encog.ml.MLRegression; import org.encog.ml.data.MLDataSet; import org.encog.ml.train.MLTrain; import org.encog.ml.train.strategy.end.EndTrainingStrategy; import org.encog.util.obj.SerializeObject; import org.encog.util.simple.EncogUtility; import java.io.Serializable; /** * A simple early stopping strategy that halts training when the validation set no longer improves. */ public class EarlyStoppingStrategy implements EndTrainingStrategy { /** * The validation set. */ private MLDataSet validationSet; /** * The trainer. */ private MLTrain train; /** * Has training stopped. */ private boolean stop; /** * Current training error. */ private double trainingError; /** * Current validation error. */ private double lastValidationError; /** * The model that is being trained. */ private MLRegression model; /** * The frequency to check the validation set. */ private int checkFrequency; /** * How many iterations since the validation set was last checked. */ private int lastCheck; /** * The number of iterations that the validation is allowed to remain stagnant/degrading for. */ private int allowedStagnantIterations; private int stagnantIterations; /** * The best model so far. */ private MLRegression bestModel; private boolean saveBest; private double bestValidationError; private double minimumImprovement = Encog.DEFAULT_DOUBLE_EQUAL; public EarlyStoppingStrategy(MLDataSet theValidationSet) { this(theValidationSet, 5, 50); } public EarlyStoppingStrategy(MLDataSet theValidationSet, int theCheckFrequency, int theAllowedStagnantIterations) { this.validationSet = theValidationSet; this.checkFrequency = theCheckFrequency; this.allowedStagnantIterations = theAllowedStagnantIterations; } /** * {@inheritDoc} */ @Override public void init(MLTrain theTrain) { this.train = theTrain; this.model = (MLRegression) train.getMethod(); this.stop = false; this.lastCheck = 0; this.lastValidationError = Double.POSITIVE_INFINITY; } /** * {@inheritDoc} */ @Override public void preIteration() { } /** * {@inheritDoc} */ @Override public void postIteration() { this.lastCheck++; this.trainingError = this.train.getError(); if( this.lastCheck>this.checkFrequency || Double.isInfinite(this.lastValidationError) ) { double currentValidationError = EncogUtility.calculateRegressionError(this.model, this.validationSet); double improve = this.bestValidationError-currentValidationError; improve = Math.max(improve,0); if( Double.isInfinite(currentValidationError) || Double.isNaN(currentValidationError) ) { stop = true; } else if( this.bestValidationError<=currentValidationError && !Double.isInfinite(this.lastValidationError) && improve<this.minimumImprovement) { // No improvement this.stagnantIterations+=this.lastCheck; if(this.stagnantIterations>this.allowedStagnantIterations) { stop = true; } } else { // Improvement if( this.saveBest ) { this.bestModel = (MLRegression) SerializeObject.serializeClone((Serializable) this.model); } this.bestValidationError = currentValidationError; this.stagnantIterations=0; } this.lastValidationError = currentValidationError; this.lastCheck = 0; } } /** * @return Returns true if we should stop. */ @Override public boolean shouldStop() { return stop; } /** * @return the trainingError */ public double getTrainingError() { return trainingError; } /** * @return The validation error. */ public double getValidationError() { return this.lastValidationError; } public int getStagnantIterations() { return stagnantIterations; } public void setStagnantIterations(int stagnantIterations) { this.stagnantIterations = stagnantIterations; } public int getAllowedStagnantIterations() { return allowedStagnantIterations; } public void setAllowedStagnantIterations(int allowedStagnantIterations) { this.allowedStagnantIterations = allowedStagnantIterations; } public boolean isSaveBest() { return saveBest; } public void setSaveBest(boolean saveBest) { this.saveBest = saveBest; } public MLRegression getBestModel() { return bestModel; } public double getBestValidationError() { return bestValidationError; } public double getMinimumImprovement() { return minimumImprovement; } public void setMinimumImprovement(double minimumImprovement) { this.minimumImprovement = minimumImprovement; } }