/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.propagation.quick;
import java.util.Random;
import org.encog.EncogError;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;
import org.encog.util.validate.ValidateNetwork;
/**
* QPROP is an efficient training method that is based on Newton's Method.
* QPROP was introduced in a paper:
*
* "An Empirical Study of Learning Speed in Back-Propagation Networks" (Scott E. Fahlman, 1988)
*
*
* http://www.heatonresearch.com/wiki/Quickprop
*
*/
public class QuickPropagation extends Propagation implements
        LearningRate {

    /**
     * Continuation tag for the last gradients.
     */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";

    /**
     * The learning rate. Also used as the multiplier that caps how much
     * larger than the previous step a new step may be.
     */
    private double learningRate;

    /**
     * The last delta values, one per weight.
     */
    private double[] lastDelta;

    /**
     * This factor times the current weight is added to the slope
     * at the start of each output epoch. Keeps weights from growing
     * too big.
     */
    private double decay = 0.0001d;

    /**
     * Used to scale for the size of the training set. Computed in
     * {@link #initOthers()} as outputEpsilon divided by the record count.
     */
    private double eps;

    /**
     * Controls the amount of linear gradient descent
     * to use in updating output weights.
     */
    private double outputEpsilon = 0.35;

    /**
     * Used in computing whether the proposed step is
     * too large. Related to learningRate: {@link #initOthers()} sets it to
     * learningRate / (1 + learningRate).
     */
    private double shrink;

    /**
     * Construct a QPROP trainer for flat networks. Uses a learning rate of 2.
     *
     * @param network
     *            The network to train.
     * @param training
     *            The training data.
     */
    public QuickPropagation(final ContainsFlat network, final MLDataSet training) {
        this(network, training, 2.0);
    }

    /**
     * Construct a QPROP trainer for flat networks.
     *
     * @param network
     *            The network to train.
     * @param training
     *            The training data.
     * @param theLearningRate
     *            The learning rate. 2 is a good suggestion as
     *            a learning rate to start with. If it fails to converge,
     *            then drop it. Just like backprop, except QPROP can
     *            take higher learning rates.
     */
    public QuickPropagation(final ContainsFlat network,
            final MLDataSet training, final double theLearningRate) {
        super(network, training);
        ValidateNetwork.validateMethodToData(network, training);
        this.learningRate = theLearningRate;
        // One previous-step slot per weight; all start at zero, so the first
        // update of every weight uses only the linear term.
        this.lastDelta = new double[this.network.getFlat().getWeights().length];
    }

    /**
     * {@inheritDoc}
     *
     * This trainer always reports {@code false}, even though it implements
     * {@link #pause()} and {@link #resume(TrainingContinuation)}.
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * @return The last delta values. This is the internal array, not a copy.
     */
    public double[] getLastDelta() {
        return this.lastDelta;
    }

    /**
     * @return The learning rate, this value is essentially a percent. It is
     *         the degree to which the gradients are applied to the weight
     *         matrix to allow learning.
     */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Determine if the specified continuation object is valid to resume with.
     * The object is valid when it carries the {@link #LAST_GRADIENTS} entry,
     * was produced by this training type, and the stored gradient array has
     * one element per network weight.
     *
     * @param state
     *            The continuation object to check.
     * @return True if the specified continuation object is valid for this
     *         training method and network.
     */
    public boolean isValidResume(final TrainingContinuation state) {
        if (!state.getContents().containsKey(QuickPropagation.LAST_GRADIENTS)) {
            return false;
        }

        // Compare from the known non-null side so a continuation without a
        // training type is reported as invalid instead of throwing.
        if (!getClass().getSimpleName().equals(state.getTrainingType())) {
            return false;
        }

        // A missing value or a value of the wrong type is invalid data, not
        // a reason to fail with a NullPointerException/ClassCastException.
        final Object saved = state.get(QuickPropagation.LAST_GRADIENTS);
        if (!(saved instanceof double[])) {
            return false;
        }

        final int weightCount = ((ContainsFlat) getMethod()).getFlat()
                .getWeights().length;
        return ((double[]) saved).length == weightCount;
    }

    /**
     * Pause the training.
     *
     * @return A training continuation object to continue with. It holds the
     *         training type name and the last gradient array.
     */
    @Override
    public TrainingContinuation pause() {
        final TrainingContinuation result = new TrainingContinuation();
        result.setTrainingType(this.getClass().getSimpleName());
        result.set(QuickPropagation.LAST_GRADIENTS, this.getLastGradient());
        return result;
    }

    /**
     * Resume training.
     *
     * @param state
     *            The training state to return to.
     * @throws TrainingError
     *             If the state fails {@link #isValidResume(TrainingContinuation)}.
     */
    @Override
    public void resume(final TrainingContinuation state) {
        if (!isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }
        final double[] lastGradient = (double[]) state
                .get(QuickPropagation.LAST_GRADIENTS);
        // Copy into the trainer's own array rather than adopting the saved one.
        EngineArray.arrayCopy(lastGradient, this.getLastGradient());
    }

    /**
     * Set the learning rate, this value is essentially a percent. It is the
     * degree to which the gradients are applied to the weight matrix to allow
     * learning. Note that this does not recompute the shrink factor; that is
     * only derived from the learning rate in {@link #initOthers()}.
     *
     * @param rate
     *            The learning rate.
     */
    @Override
    public void setLearningRate(final double rate) {
        this.learningRate = rate;
    }

    /**
     * @return the outputEpsilon
     */
    public double getOutputEpsilon() {
        return this.outputEpsilon;
    }

    /**
     * @return the shrink
     */
    public double getShrink() {
        return this.shrink;
    }

    /**
     * @param s the shrink to set
     */
    public void setShrink(double s) {
        this.shrink = s;
    }

    /**
     * @param theOutputEpsilon the outputEpsilon to set
     */
    public void setOutputEpsilon(double theOutputEpsilon) {
        this.outputEpsilon = theOutputEpsilon;
    }

    /**
     * Perform training method specific init. Derives the per-record epsilon
     * from the output epsilon and the training set size, and derives the
     * shrink factor from the learning rate (overwriting any value previously
     * set through {@link #setShrink(double)}).
     */
    public void initOthers() {
        this.eps = this.outputEpsilon / getTraining().getRecordCount();
        this.shrink = this.learningRate / (1.0 + this.learningRate);
    }

    /**
     * Update a weight.
     *
     * @param gradients
     *            The gradients.
     * @param lastGradient
     *            The last gradients.
     * @param index
     *            The index.
     * @return The weight delta.
     */
    @Override
    public double updateWeight(final double[] gradients,
            final double[] lastGradient, final int index) {
        return computeStep(gradients, lastGradient, index);
    }

    /**
     * Update a weight with dropout. When the weight is dropped for this
     * update, zero is returned and neither the stored last delta nor the
     * stored last gradient for the weight is touched.
     *
     * @param gradients
     *            The gradients.
     * @param lastGradient
     *            The last gradients.
     * @param index
     *            The index.
     * @param dropoutRate
     *            The dropout rate.
     * @return The weight delta.
     */
    @Override
    public double updateWeight(final double[] gradients,
            final double[] lastGradient, final int index, double dropoutRate) {
        if (dropoutRate > 0 && dropoutRandomSource.nextDouble() < dropoutRate) {
            return 0;
        }
        return computeStep(gradients, lastGradient, index);
    }

    /**
     * Compute the quickprop step for one weight and record it. Shared by both
     * {@code updateWeight} overloads so the update rule exists only once.
     *
     * @param gradients
     *            The gradients.
     * @param lastGradient
     *            The last gradients.
     * @param index
     *            The index of the weight being updated.
     * @return The weight delta.
     */
    private double computeStep(final double[] gradients,
            final double[] lastGradient, final int index) {
        final double w = this.network.getFlat().getWeights()[index];
        final double d = this.lastDelta[index];
        // Current slope: negated gradient plus the weight-decay term.
        final double s = -gradients[index] + this.decay * w;
        // Previous slope.
        final double p = -lastGradient[index];

        double nextStep = 0.0;

        // The step must always be in direction opposite to the slope.
        if (d < 0.0) {
            // If last step was negative...
            if (s > 0.0) {
                // Add in linear term if current slope is still positive.
                nextStep -= this.eps * s;
            }
            // If current slope is close to or larger than prev slope...
            if (s >= (this.shrink * p)) {
                // Take maximum size negative step.
                nextStep += this.learningRate * d;
            } else {
                // Else, use quadratic estimate.
                nextStep += d * s / (p - s);
            }
        } else if (d > 0.0) {
            // If last step was positive...
            if (s < 0.0) {
                // Add in linear term if current slope is still negative.
                nextStep -= this.eps * s;
            }
            // If current slope is close to or more neg than prev slope...
            if (s <= (this.shrink * p)) {
                // Take maximum size positive step.
                nextStep += this.learningRate * d;
            } else {
                // Else, use quadratic estimate.
                nextStep += d * s / (p - s);
            }
        } else {
            // Last step was zero, so use only linear term.
            nextStep -= this.eps * s;
        }

        // update global data arrays
        this.lastDelta[index] = nextStep;
        this.getLastGradient()[index] = gradients[index];

        return nextStep;
    }

    /**
     * Do not allow batch sizes other than 0, not supported.
     *
     * @param theBatchSize
     *            The requested batch size; only 0 is accepted.
     * @throws EncogError
     *             If a non-zero batch size is requested.
     */
    public void setBatchSize(int theBatchSize) {
        if (theBatchSize != 0) {
            throw new EncogError("Online training is not supported for:" + this.getClass().getSimpleName());
        }
    }
}