/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 *
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.networks.training.propagation.back;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.strategy.SmartLearningRate;
import org.encog.neural.networks.training.strategy.SmartMomentum;
import org.encog.util.validate.ValidateNetwork;

/**
 * This class implements a backpropagation training algorithm for feedforward
 * neural networks. It is used in the same manner as any other training class
 * that implements the Train interface.
 *
 * Backpropagation is a common neural network training algorithm. It works by
 * analyzing the error at the output of the neural network. The contribution of
 * each output-layer neuron to this error, according to its weights, is
 * determined. The weights are then adjusted to reduce this error, and the
 * process continues working its way backwards through the layers of the
 * neural network.
 *
 * This implementation of the backpropagation algorithm uses both momentum and
 * a learning rate. The learning rate specifies the degree to which the weight
 * matrices will be modified through each iteration. The momentum specifies how
 * much the previous learning iteration affects the current one. To use no
 * momentum at all, specify zero.
 *
 * One primary problem with backpropagation is that the magnitude of the
 * partial derivative is often detrimental to the training of the neural
 * network. The Manhattan and resilient propagation methods address this issue
 * in different ways. In general, resilient propagation is suggested over
 * backpropagation for most Encog training tasks.
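 *
 * A typical training loop looks like the following (an illustrative sketch:
 * the network and training set are assumed to already exist, and the 0.01
 * error target is an arbitrary choice, not part of this API):
 *
 * <pre>
 * Backpropagation train = new Backpropagation(network, trainingSet, 0.7, 0.3);
 * do {
 * 	train.iteration();
 * } while (train.getError() &gt; 0.01);
 * </pre>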
 */
public class Backpropagation extends Propagation implements Momentum,
		LearningRate {

	/**
	 * The resume key for backpropagation.
	 */
	public static final String LAST_DELTA = "LAST_DELTA";

	/**
	 * The learning rate.
	 */
	private double learningRate;

	/**
	 * The momentum.
	 */
	private double momentum;

	/**
	 * The last delta values.
	 */
	private double[] lastDelta;

	/**
	 * Should Nesterov momentum be used?
	 */
	private boolean nesterovUpdate;

	/**
	 * Create a class to train using backpropagation. The learning rate and
	 * momentum are chosen automatically by the SmartLearningRate and
	 * SmartMomentum strategies. Training is performed on the CPU.
	 *
	 * @param network
	 *            The network that is to be trained.
	 * @param training
	 *            The training data to be used for backpropagation.
	 */
	public Backpropagation(final ContainsFlat network, final MLDataSet training) {
		this(network, training, 0, 0);
		addStrategy(new SmartLearningRate());
		addStrategy(new SmartMomentum());
	}

	/**
	 * Create a class to train using backpropagation.
	 *
	 * @param network
	 *            The network that is to be trained.
	 * @param training
	 *            The training set.
	 * @param theLearnRate
	 *            The rate at which the weight matrix will be adjusted based on
	 *            learning.
	 * @param theMomentum
	 *            The influence that previous iteration's training deltas will
	 *            have on the current iteration.
	 */
	public Backpropagation(final ContainsFlat network,
			final MLDataSet training, final double theLearnRate,
			final double theMomentum) {
		super(network, training);
		ValidateNetwork.validateMethodToData(network, training);
		this.momentum = theMomentum;
		this.learningRate = theLearnRate;
		this.lastDelta = new double[network.getFlat().getWeights().length];
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public boolean canContinue() {
		return false;
	}

	/**
	 * @return The last delta values.
	 */
	public double[] getLastDelta() {
		return this.lastDelta;
	}

	/**
	 * @return The learning rate. This value is essentially a percent; it is
	 *         the degree to which the gradients are applied to the weight
	 *         matrix to allow learning.
	 */
	@Override
	public double getLearningRate() {
		return this.learningRate;
	}

	/**
	 * @return The momentum for training. This is the degree to which changes
	 *         from the previous training iteration will affect this training
	 *         iteration. It can be useful to overcome local minima.
	 */
	@Override
	public double getMomentum() {
		return this.momentum;
	}

	/**
	 * Determine if the specified continuation object is valid to resume with.
	 *
	 * @param state
	 *            The continuation object to check.
	 * @return True if the specified continuation object is valid for this
	 *         training method and network.
	 */
	public boolean isValidResume(final TrainingContinuation state) {
		if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) {
			return false;
		}

		if (!state.getTrainingType().equals(getClass().getSimpleName())) {
			return false;
		}

		final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA);
		return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length;
	}

	/**
	 * Pause the training.
	 *
	 * @return A training continuation object to continue with.
	 */
	@Override
	public TrainingContinuation pause() {
		final TrainingContinuation result = new TrainingContinuation();
		result.setTrainingType(this.getClass().getSimpleName());
		result.set(Backpropagation.LAST_DELTA, this.lastDelta);
		return result;
	}

	/**
	 * Resume training.
	 *
	 * @param state
	 *            The training state to return to.
	 */
	@Override
	public void resume(final TrainingContinuation state) {
		if (!isValidResume(state)) {
			throw new TrainingError("Invalid training resume data length");
		}

		this.lastDelta = (double[]) state.get(Backpropagation.LAST_DELTA);
	}

	/**
	 * Set the learning rate. This value is essentially a percent; it is the
	 * degree to which the gradients are applied to the weight matrix to allow
	 * learning.
	 *
	 * @param rate
	 *            The learning rate.
	 */
	@Override
	public void setLearningRate(final double rate) {
		this.learningRate = rate;
	}

	/**
	 * Set the momentum for training. This is the degree to which changes from
	 * the previous training iteration will affect this training iteration. It
	 * can be useful to overcome local minima.
	 *
	 * @param m
	 *            The momentum.
	 */
	@Override
	public void setMomentum(final double m) {
		this.momentum = m;
	}
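	/*
	 * The weight updates implemented below. With classic momentum, the applied
	 * delta is the current gradient step plus a fraction of the previous
	 * delta, i.e. a decaying sum of past gradient steps:
	 *
	 *     delta[i] = learningRate * gradients[i] + momentum * lastDelta[i]
	 *
	 * With Nesterov momentum, the velocity stored in lastDelta is advanced
	 * first and the applied delta "looks ahead" along the new velocity
	 * (following the common Sutskever/Bengio reformulation):
	 *
	 *     v        = momentum * vPrev + learningRate * gradients[i]
	 *     delta[i] = (1 + momentum) * v - momentum * vPrev
	 */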
	/**
	 * Update a weight.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @return The weight delta.
	 */
	@Override
	public double updateWeight(final double[] gradients,
			final double[] lastGradient, final int index) {
		final double delta = (gradients[index] * this.learningRate)
				+ (this.lastDelta[index] * this.momentum);
		this.lastDelta[index] = delta;
		return delta;
	}

	/**
	 * Update a weight, dispatching to either the classic or the Nesterov
	 * update.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @param dropoutRate
	 *            The dropout rate.
	 * @return The weight delta.
	 */
	@Override
	public double updateWeight(final double[] gradients,
			final double[] lastGradient, final int index, double dropoutRate) {
		if (this.nesterovUpdate) {
			return updateWeightNesterov(gradients, lastGradient, index,
					dropoutRate);
		} else {
			return updateWeightNormal(gradients, lastGradient, index,
					dropoutRate);
		}
	}

	/**
	 * Update a weight using classic momentum.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @param dropoutRate
	 *            The dropout rate.
	 * @return The weight delta.
	 */
	private double updateWeightNormal(final double[] gradients,
			final double[] lastGradient, final int index, double dropoutRate) {
		if (dropoutRate > 0 && dropoutRandomSource.nextDouble() < dropoutRate) {
			return 0;
		}

		final double delta = (gradients[index] * this.learningRate)
				+ (this.lastDelta[index] * this.momentum);
		this.lastDelta[index] = delta;
		return delta;
	}

	/**
	 * Update a weight using Nesterov momentum.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @param dropoutRate
	 *            The dropout rate.
	 * @return The weight delta.
	 */
	private double updateWeightNesterov(final double[] gradients,
			final double[] lastGradient, final int index, double dropoutRate) {
		if (dropoutRate > 0 && dropoutRandomSource.nextDouble() < dropoutRate) {
			return 0;
		}

		// Advance the velocity. lastDelta holds the velocity here, so it must
		// not be overwritten with the applied delta afterwards.
		final double prevNesterov = this.lastDelta[index];
		this.lastDelta[index] = (this.momentum * prevNesterov)
				+ (gradients[index] * this.learningRate);

		// Look-ahead step: (1 + momentum) * v - momentum * vPrev.
		return ((1 + this.momentum) * this.lastDelta[index])
				- (this.momentum * prevNesterov);
	}

	/**
	 * Perform training method specific init.
	 */
	public void initOthers() {
	}

	/**
	 * @return True if the Nesterov update is used.
	 */
	public boolean isNesterovUpdate() {
		return this.nesterovUpdate;
	}

	/**
	 * Set whether the Nesterov update should be used.
	 *
	 * @param nesterovUpdate
	 *            True to use the Nesterov update.
	 */
	public void setNesterovUpdate(final boolean nesterovUpdate) {
		this.nesterovUpdate = nesterovUpdate;
	}
}
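
/*
 * Usage sketch (illustrative; assumes "network" and "trainingSet" already
 * exist): the Nesterov update is off by default and is enabled per trainer.
 *
 *     Backpropagation train = new Backpropagation(network, trainingSet, 0.7, 0.3);
 *     train.setNesterovUpdate(true);
 *     train.iteration();
 */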