/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.strategy;
import org.encog.neural.data.Indexable;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Strategy;
import org.encog.neural.networks.training.Train;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Attempt to automatically set the learning rate in a learning method that
* supports a learning rate.
*
* @author jheaton
*
*/
public class SmartLearningRate implements Strategy {
/**
* Learning decay rate.
*/
public static final double LEARNING_DECAY = 0.99;
/**
* The training algorithm that is using this strategy.
*/
private Train train;
/**
* The class that is to have the learning rate set for.
*/
private LearningRate setter;
/**
* The current learning rate.
*/
private double currentLearningRate;
/**
* The training set size, this is used to pick an initial learning rate.
*/
private long trainingSize;
/**
* The error rate from the previous iteration.
*/
private double lastError;
/**
* Has one iteration passed, and we are now ready to start evaluation.
*/
private boolean ready;
/**
* The logging object.
*/
private final Logger logger = LoggerFactory.getLogger(this.getClass());
/**
* Determine the training size.
*
* @return The training size.
*/
private long determineTrainingSize() {
long result = 0;
if (this.train instanceof Indexable) {
result = ((Indexable) this).getRecordCount();
} else {
for (@SuppressWarnings("unused")
final NeuralDataPair pair : this.train.getTraining()) {
result++;
}
}
return result;
}
/**
* Initialize this strategy.
*
* @param train
* The training algorithm.
*/
public void init(final Train train) {
this.train = train;
this.ready = false;
this.setter = (LearningRate) train;
this.trainingSize = determineTrainingSize();
this.currentLearningRate = 1.0 / this.trainingSize;
if (this.logger.isInfoEnabled()) {
this.logger.info("Starting learning rate: {}",
this.currentLearningRate);
}
this.setter.setLearningRate(this.currentLearningRate);
}
/**
* Called just after a training iteration.
*/
public void postIteration() {
if (this.ready) {
if (this.train.getError() > this.lastError) {
this.currentLearningRate *= SmartLearningRate.LEARNING_DECAY;
this.setter.setLearningRate(this.currentLearningRate);
if (this.logger.isInfoEnabled()) {
this.logger.info("Adjusting learning rate to {}",
this.currentLearningRate);
}
}
} else {
this.ready = true;
}
}
/**
* Called just before a training iteration.
*/
public void preIteration() {
this.lastError = this.train.getError();
}
}