/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.train.strategy.end; /* * Encog(tm) Core v3.3 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2014 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */

import org.encog.Encog;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
import org.encog.util.obj.SerializeObject;
import org.encog.util.simple.EncogUtility;

import java.io.Serializable;

/**
 * A simple early stopping strategy that halts training when the trainer's error
 * no longer improves.
 * <p>
 * After every iteration the trainer's current error ({@link MLTrain#getError()})
 * is compared with the best error seen so far. An iteration only counts as an
 * improvement when it lowers the best error by a positive amount that is at least
 * {@link #getMinimumImprovement()}; every other iteration counts as stagnant.
 * Once the number of consecutive stagnant iterations exceeds
 * {@link #getAllowedStagnantIterations()}, {@link #shouldStop()} returns
 * {@code true}. Training is also stopped immediately if the error becomes NaN
 * or infinite.
 * <p>
 * If {@link #setSaveBest(boolean)} is enabled, a serialized clone of the model is
 * taken every time the best error improves and is exposed via
 * {@link #getBestModel()}. This requires the model to be {@link Serializable}.
 */
public class StoppingStrategy implements EndTrainingStrategy {

    /**
     * Number of stagnant iterations tolerated by {@link #StoppingStrategy(MLDataSet)}.
     */
    private static final int DEFAULT_ALLOWED_STAGNANT_ITERATIONS = 50;

    /**
     * The trainer whose error is monitored.
     */
    private MLTrain train;

    /**
     * Has training stopped.
     */
    private boolean stop;

    /**
     * The trainer's error observed at the end of the most recent iteration.
     */
    private double lastError;

    /**
     * The model that is being trained.
     */
    private MLRegression model;

    /**
     * The number of consecutive stagnant iterations tolerated before stopping.
     */
    private int allowedStagnantIterations;

    /**
     * The number of consecutive iterations without sufficient improvement so far.
     */
    private int stagnantIterations;

    /**
     * Clone of the model taken at the best error so far (only when saveBest is enabled).
     */
    private MLRegression bestModel;

    /**
     * Whether a clone of the best model should be kept.
     */
    private boolean saveBest;

    /**
     * The lowest error seen so far. Starts at +infinity so the first finite error
     * always counts as an improvement.
     */
    private double bestError = Double.POSITIVE_INFINITY;

    /**
     * The minimum decrease in error that counts as an improvement.
     */
    private double minimumImprovement = Encog.DEFAULT_DOUBLE_EQUAL;

    /**
     * Construct a stopping strategy that tolerates 50 stagnant iterations.
     * <p>
     * NOTE: the supplied data set is NOT used. This strategy monitors the
     * trainer's own error, not a separate validation set.
     *
     * @param theValidationSet Ignored; retained for backward compatibility.
     */
    public StoppingStrategy(MLDataSet theValidationSet) {
        this(DEFAULT_ALLOWED_STAGNANT_ITERATIONS);
    }

    /**
     * Construct a stopping strategy.
     *
     * @param theAllowedStagnantIterations The number of consecutive stagnant
     *                                     iterations tolerated before stopping.
     */
    public StoppingStrategy(int theAllowedStagnantIterations) {
        this.allowedStagnantIterations = theAllowedStagnantIterations;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void init(MLTrain theTrain) {
        this.train = theTrain;
        this.model = (MLRegression) theTrain.getMethod();
        this.stop = false;
        this.lastError = Double.POSITIVE_INFINITY;
        // Reset per-run progress so a fresh training run starts from scratch.
        this.bestError = Double.POSITIVE_INFINITY;
        this.stagnantIterations = 0;
        this.bestModel = null;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void preIteration() {
    }

    /**
     * {@inheritDoc}
     * <p>
     * Reads the trainer's error. It stops on NaN/infinite error, records a new
     * best when the error improves by at least the minimum improvement, and
     * otherwise counts a stagnant iteration.
     */
    @Override
    public void postIteration() {
        final double trainingError = this.train.getError();

        if (Double.isNaN(trainingError) || Double.isInfinite(trainingError)) {
            // The error has diverged; further training is pointless.
            this.stop = true;
        } else if (isImprovement(trainingError)) {
            if (this.saveBest) {
                this.bestModel = (MLRegression) SerializeObject.serializeClone((Serializable) this.model);
            }
            this.bestError = trainingError;
            this.stagnantIterations = 0;
        } else {
            // No (sufficient) improvement.
            this.stagnantIterations++;
            if (this.stagnantIterations > this.allowedStagnantIterations) {
                this.stop = true;
            }
        }

        this.lastError = trainingError;
    }

    /**
     * Decide whether an error is a sufficient improvement over the best error so far.
     * An improvement must be strictly positive and at least the minimum improvement,
     * so a zero minimum still requires the error to actually decrease.
     *
     * @param trainingError The finite error from the latest iteration.
     * @return True if the error should replace the current best error.
     */
    private boolean isImprovement(double trainingError) {
        final double improvement = this.bestError - trainingError;
        return improvement > 0 && improvement >= this.minimumImprovement;
    }

    /**
     * @return Returns true if we should stop.
     */
    @Override
    public boolean shouldStop() {
        return stop;
    }

    /**
     * @return The number of consecutive stagnant iterations observed so far.
     */
    public int getStagnantIterations() {
        return stagnantIterations;
    }

    /**
     * @param stagnantIterations The number of consecutive stagnant iterations.
     */
    public void setStagnantIterations(int stagnantIterations) {
        this.stagnantIterations = stagnantIterations;
    }

    /**
     * @return The number of consecutive stagnant iterations tolerated before stopping.
     */
    public int getAllowedStagnantIterations() {
        return allowedStagnantIterations;
    }

    /**
     * @param allowedStagnantIterations The number of consecutive stagnant
     *                                  iterations tolerated before stopping.
     */
    public void setAllowedStagnantIterations(int allowedStagnantIterations) {
        this.allowedStagnantIterations = allowedStagnantIterations;
    }

    /**
     * @return True if a clone of the best model is kept.
     */
    public boolean isSaveBest() {
        return saveBest;
    }

    /**
     * @param saveBest True to keep a serialized clone of the model each time the
     *                 best error improves.
     */
    public void setSaveBest(boolean saveBest) {
        this.saveBest = saveBest;
    }

    /**
     * @return The clone of the model at its best error, or null if saveBest is
     * disabled or no iteration has completed since {@link #init(MLTrain)}.
     */
    public MLRegression getBestModel() {
        return bestModel;
    }

    /**
     * @return The minimum decrease in error that counts as an improvement.
     */
    public double getMinimumImprovement() {
        return minimumImprovement;
    }

    /**
     * @param minimumImprovement The minimum decrease in error that counts as an
     *                           improvement.
     */
    public void setMinimumImprovement(double minimumImprovement) {
        this.minimumImprovement = minimumImprovement;
    }
}