/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.train.strategy;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
/**
 * This strategy will indicate once training is no longer improving the neural
 * network by a specified amount, over a specified number of cycles. This allows
 * the program to automatically determine when to stop training.
 *
 * <p>After each iteration the current training error is compared with the best
 * (lowest) error seen so far. An iteration counts as a "bad cycle" when the error
 * did not drop below that best error by at least {@code minImprovement}. This
 * includes iterations where the error stayed flat, rose, or was {@code NaN}. Any
 * sufficiently improving iteration resets the bad-cycle count. Once more than
 * {@code toleratedCycles} consecutive bad cycles have occurred, {@link #shouldStop()}
 * returns {@code true}.
 *
 * @author jheaton
 *
 */
public class StopTrainingStrategy implements EndTrainingStrategy {
	/**
	 * The default minimum improvement before training stops.
	 */
	public static final double DEFAULT_MIN_IMPROVEMENT = 0.0000001;
	/**
	 * The default number of cycles to tolerate.
	 */
	public static final int DEFAULT_TOLERATE_CYCLES = 100;
	/**
	 * The training algorithm that is using this strategy.
	 */
	private MLTrain train;
	/**
	 * Flag to indicate if training should stop.
	 */
	private boolean shouldStop;
	/**
	 * Has one iteration passed, and we are now ready to start evaluation.
	 */
	private boolean ready;
	/**
	 * The error rate from the previous iteration.
	 */
	private double lastError;
	/**
	 * The best (lowest, non-NaN) error rate observed so far.
	 */
	private double bestError;
	/**
	 * The minimum improvement before training stops.
	 */
	private final double minImprovement;
	/**
	 * The number of cycles to tolerate the minimum improvement.
	 */
	private final int toleratedCycles;
	/**
	 * The number of consecutive bad training cycles.
	 */
	private int badCycles;

	/**
	 * Construct the strategy with default options.
	 */
	public StopTrainingStrategy() {
		this(StopTrainingStrategy.DEFAULT_MIN_IMPROVEMENT,
				StopTrainingStrategy.DEFAULT_TOLERATE_CYCLES);
	}

	/**
	 * Construct the strategy with the specified parameters.
	 * @param minImprovement The minimum accepted improvement.
	 * @param toleratedCycles The number of cycles to tolerate before stopping.
	 */
	public StopTrainingStrategy(final double minImprovement,
			final int toleratedCycles) {
		this.minImprovement = minImprovement;
		this.toleratedCycles = toleratedCycles;
		this.badCycles = 0;
		this.bestError = Double.MAX_VALUE;
	}

	/**
	 * {@inheritDoc}
	 *
	 * <p>Also resets all progress tracking (best error and bad-cycle count), so
	 * state from a previous training run does not leak into this one.
	 */
	public void init(final MLTrain train) {
		this.train = train;
		this.shouldStop = false;
		this.ready = false;
		this.badCycles = 0;
		this.bestError = Double.MAX_VALUE;
	}

	/**
	 * {@inheritDoc}
	 *
	 * <p>Compares the trainer's current error against the best error seen so far
	 * and updates the bad-cycle count. The first call only records a baseline.
	 */
	public void postIteration() {
		// Read the error once so every comparison below sees the same value.
		final double currentError = this.train.getError();

		if (this.ready) {
			// Signed improvement: a rise in error is negative, never "improving".
			// Written as "not (improvement >= min)" so a NaN error is a bad cycle.
			final double improvement = this.bestError - currentError;
			if (improvement >= this.minImprovement) {
				this.badCycles = 0;
			} else {
				this.badCycles++;
				if (this.badCycles > this.toleratedCycles) {
					this.shouldStop = true;
				}
			}
		} else {
			this.ready = true;
		}

		this.lastError = currentError;
		// Unlike Math.min, this comparison is false for NaN, so a NaN error can
		// never become the best error and freeze the strategy.
		if (currentError < this.bestError) {
			this.bestError = currentError;
		}
	}

	/**
	 * {@inheritDoc}
	 *
	 * <p>This strategy does nothing before an iteration.
	 */
	public void preIteration() {
	}

	/**
	 * {@inheritDoc}
	 *
	 * @return {@code true} once more than the tolerated number of consecutive
	 *         iterations have failed to reach the minimum improvement.
	 */
	public boolean shouldStop() {
		return this.shouldStop;
	}
}