/*
 * Encog(tm) Core v2.5 - Java Version
 * http://www.heatonresearch.com/encog/
 * http://code.google.com/p/encog-java/
 * Copyright 2008-2010 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.networks.training.propagation.resilient;

import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.network.train.prop.RPROPConst;
import org.encog.engine.network.train.prop.TrainFlatNetworkOpenCL;
import org.encog.engine.network.train.prop.TrainFlatNetworkResilient;
import org.encog.engine.util.EngineArray;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * One problem with the backpropagation algorithm is that the magnitude of the
 * partial derivative is usually too large or too small. Further, the learning
 * rate is a single value for the entire neural network. The resilient
 * propagation learning algorithm uses a special update value (similar to the
 * learning rate) for every neuron connection. Further, these update values are
 * automatically determined, unlike the learning rate of the backpropagation
 * algorithm.
 *
 * For most training situations, we suggest that the resilient propagation
 * algorithm (this class) be used for training.
 *
 * There are a total of three parameters that must be provided to the resilient
 * training algorithm. Defaults are provided for each, and in nearly all cases,
 * these defaults are acceptable. This makes the resilient propagation
 * algorithm one of the easiest and most efficient training algorithms
 * available.
 *
 * The optional parameters are:
 *
 * zeroTolerance - How close a number must be to zero to be considered zero.
 * The default is 0.00000000000000001 (1e-17).
 *
 * initialUpdate - The initial update value for each matrix value. The default
 * is 0.1.
 *
 * maxStep - The largest amount by which an update value can step. The default
 * is 50.
 *
 * Usually you will not need to use these, and you should use the constructor
 * that does not require them.
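 *
 * For reference, the weight update these parameters feed into can be sketched
 * as follows. This is an illustrative outline of the standard RPROP rule
 * (Riedmiller and Braun); the actual implementation lives in
 * TrainFlatNetworkResilient:
 *
 * <pre>
 * double change = gradient * lastGradient;
 * if (change &gt; zeroTolerance) {
 *     // gradient kept its sign: accelerate, but never past maxStep
 *     update = Math.min(update * 1.2, maxStep);
 * } else if (change &lt; -zeroTolerance) {
 *     // gradient changed sign: a minimum was overshot, so back off
 *     update = update * 0.5;
 * }
 * weight += -Math.signum(gradient) * update;
 * </pre>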
 *
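 * A minimal usage sketch (BasicNeuralDataSet, the input/ideal arrays and the
 * stopping threshold below are illustrative, not requirements of this class):
 *
 * <pre>
 * NeuralDataSet training = new BasicNeuralDataSet(input, ideal);
 * ResilientPropagation train = new ResilientPropagation(network, training);
 * do {
 *     train.iteration();
 * } while (train.getError() &gt; 0.01);
 * </pre>
 *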
 * @author jheaton
 *
 */
public class ResilientPropagation extends Propagation {

    /**
     * Continuation tag for the last gradients.
     */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";

    /**
     * Continuation tag for the update values.
     */
    public static final String UPDATE_VALUES = "UPDATE_VALUES";

    /**
     * Construct a resilient training object. Use the defaults for all training
     * parameters. Usually this is the constructor to use, as the resilient
     * training algorithm is designed for the default parameters to be
     * acceptable for nearly all problems. Trains on the CPU.
     *
     * @param network
     *            The network to train.
     * @param training
     *            The training set to use.
     */
    public ResilientPropagation(final BasicNetwork network,
            final NeuralDataSet training) {
        this(network, training, null, RPROPConst.DEFAULT_INITIAL_UPDATE,
                RPROPConst.DEFAULT_MAX_STEP);
    }

    /**
     * Construct an RPROP trainer, allowing an OpenCL device to be specified.
     * Uses the defaults for all training parameters. Usually this is the
     * constructor to use, as the resilient training algorithm is designed for
     * the default parameters to be acceptable for nearly all problems.
     *
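     * For example (a sketch only; it assumes Encog's OpenCL subsystem has
     * already been initialized via Encog.getInstance().initCL()):
     *
     * <pre>
     * OpenCLTrainingProfile profile = new OpenCLTrainingProfile(Encog
     *         .getInstance().getCL().chooseDevice());
     * ResilientPropagation train = new ResilientPropagation(network,
     *         training, profile);
     * </pre>
     *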
     * @param network
     *            The network to train.
     * @param training
     *            The training data to use.
     * @param profile
     *            The OpenCL training profile to use.
     */
    public ResilientPropagation(final BasicNetwork network,
            final NeuralDataSet training, final OpenCLTrainingProfile profile) {
        this(network, training, profile, RPROPConst.DEFAULT_INITIAL_UPDATE,
                RPROPConst.DEFAULT_MAX_STEP);
    }

    /**
     * Construct a resilient training object, allowing the training parameters
     * to be specified. Usually the default parameters are acceptable for the
     * resilient training algorithm, so you should usually use one of the other
     * constructors, which make use of the default values.
     *
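     * For example (the parameter values are illustrative; a null profile
     * selects CPU training):
     *
     * <pre>
     * ResilientPropagation train = new ResilientPropagation(network,
     *         training, null, 0.05, 25);
     * </pre>
     *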
     * @param network
     *            The network to train.
     * @param training
     *            The training set to use.
     * @param profile
     *            Optional OpenCL training profile to execute on; pass null to
     *            train on the CPU.
     * @param initialUpdate
     *            The initial update values; this is the amount that the deltas
     *            are all initially set to.
     * @param maxStep
     *            The maximum that a delta can reach.
     */
    public ResilientPropagation(final BasicNetwork network,
            final NeuralDataSet training, final OpenCLTrainingProfile profile,
            final double initialUpdate, final double maxStep) {
        super(network, training);

        if (profile == null) {
            // No OpenCL profile: train on the CPU, passing the requested
            // RPROP parameters through to the flat trainer rather than
            // silently falling back to the engine defaults.
            final TrainFlatNetworkResilient rpropFlat = new TrainFlatNetworkResilient(
                    network.getStructure().getFlat(), this.getTraining(),
                    RPROPConst.DEFAULT_ZERO_TOLERANCE, initialUpdate, maxStep);
            this.setFlatTraining(rpropFlat);
        } else {
            // An OpenCL profile was supplied: train on that OpenCL device.
            final TrainFlatNetworkOpenCL rpropFlat = new TrainFlatNetworkOpenCL(
                    network.getStructure().getFlat(), this.getTraining(),
                    profile);
            rpropFlat.learnRPROP(initialUpdate, maxStep);
            this.setFlatTraining(rpropFlat);
        }
    }

    /**
     * @return True, as RPROP training can be paused and continued.
     */
    public boolean canContinue() {
        return true;
    }

    /**
     * Determine if the specified continuation object is valid to resume with.
     *
     * @param state
     *            The continuation object to check.
     * @return True if the specified continuation object is valid for this
     *         training method and network.
     */
    public boolean isValidResume(final TrainingContinuation state) {
        // Both the last gradients and the update values must be present.
        if (!state.getContents().containsKey(
                ResilientPropagation.LAST_GRADIENTS)
                || !state.getContents().containsKey(
                        ResilientPropagation.UPDATE_VALUES)) {
            return false;
        }

        // The saved gradient array must match the size of this network.
        final double[] d = (double[]) state
                .get(ResilientPropagation.LAST_GRADIENTS);
        return d.length == getNetwork().getStructure().calculateSize();
    }

    /**
     * Pause the training.
     *
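     * For example, a pause/resume round trip might look like this (a sketch;
     * how the continuation object is stored in the meantime is up to the
     * caller):
     *
     * <pre>
     * TrainingContinuation cont = train.pause();
     * // ... later, once isValidResume(cont) holds ...
     * train.resume(cont);
     * </pre>
     *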
     * @return A training continuation object to continue with.
     */
    public TrainingContinuation pause() {
        final TrainingContinuation result = new TrainingContinuation();

        if (this.getFlatTraining() instanceof TrainFlatNetworkResilient) {
            // CPU training: capture the state of the flat resilient trainer.
            result.set(ResilientPropagation.LAST_GRADIENTS,
                    ((TrainFlatNetworkResilient) this.getFlatTraining())
                            .getLastGradient());
            result.set(ResilientPropagation.UPDATE_VALUES,
                    ((TrainFlatNetworkResilient) this.getFlatTraining())
                            .getUpdateValues());
        } else {
            // OpenCL training: capture the state of the OpenCL trainer.
            result.set(ResilientPropagation.LAST_GRADIENTS,
                    ((TrainFlatNetworkOpenCL) this.getFlatTraining())
                            .getLastGradient());
            result.set(ResilientPropagation.UPDATE_VALUES,
                    ((TrainFlatNetworkOpenCL) this.getFlatTraining())
                            .getUpdateValues());
        }

        return result;
    }

    /**
     * Resume training.
     *
     * @param state
     *            The training state to return to.
     */
    public void resume(final TrainingContinuation state) {
        if (!isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }

        final double[] lastGradient = (double[]) state
                .get(ResilientPropagation.LAST_GRADIENTS);
        final double[] updateValues = (double[]) state
                .get(ResilientPropagation.UPDATE_VALUES);

        if (this.getFlatTraining() instanceof TrainFlatNetworkResilient) {
            // CPU training: restore the saved state into the flat trainer.
            EngineArray.arrayCopy(lastGradient,
                    ((TrainFlatNetworkResilient) this.getFlatTraining())
                            .getLastGradient());
            EngineArray.arrayCopy(updateValues,
                    ((TrainFlatNetworkResilient) this.getFlatTraining())
                            .getUpdateValues());
        } else if (this.getFlatTraining() instanceof TrainFlatNetworkOpenCL) {
            // OpenCL training: restore the saved state into the OpenCL trainer.
            EngineArray.arrayCopy(lastGradient,
                    ((TrainFlatNetworkOpenCL) this.getFlatTraining())
                            .getLastGradient());
            EngineArray.arrayCopy(updateValues,
                    ((TrainFlatNetworkOpenCL) this.getFlatTraining())
                            .getUpdateValues());
        }
    }
}