/** * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neuroph.nnet.learning; import org.neuroph.core.NeuralNetwork; import org.neuroph.core.learning.TrainingSet; /** * Backpropagation learning rule with dynamic learning rate and momentum * @author Zoran Sevarac <sevarac@gmail.com> */ public class DynamicBackPropagation extends MomentumBackpropagation{ private static final long serialVersionUID = 1L; private double maxLearningRate = 0.9d; private double minLearningRate = 0.1d; private double learningRateChange = 0.99926d; private boolean useDynamicLearningRate = true; private double maxMomentum = 0.9d; private double minMomentum = 0.1d; private double momentumChange = 0.99926d; private boolean useDynamicMomentum = true; // private double previousNetworkError; public DynamicBackPropagation() { super(); } // Adjusting learning rate dynamically /* If network error of current epoch is higher than the network error of the previous * epoch the learning rate is adjusted by minus 1 per cent of current learning rate. * Otherwise the learning rate is adjusted by plus 1 per cent of current learning * rate. So, learning rate increases faster than decreasing does. But if learning rate * reaches 0.9 it switches back to 0.5 to avoid endless training. The lowest learning * rate is 0.5 also to avoid endless training. */ protected void adjustLearningRate() { // 1. First approach - probably the best // bigger error -> smaller learning rate; minimize the error growth // smaller error -> bigger learning rate; converege faster // the amount of earning rate change is proportional to error change - by using errorChange double errorChange = this.previousEpochError - this.totalNetworkError; this.learningRate = this.learningRate + (errorChange*learningRateChange); if (this.learningRate > this.maxLearningRate) this.learningRate = this.maxLearningRate; if (this.learningRate < this.minLearningRate) this.learningRate = this.minLearningRate; // System.out.println("Learning rate: "+this.learningRate); // 2. Second approach // doing this lineary for each epoch considering network error behaviour // probbaly the worst one /* if (this.totalNetworkError >= this.totalNetworkErrorInPreviousEpoch) { this.learningRate = this.learningRate * this.learningRateChange; if (this.learningRate < this.minLearningRate) this.learningRate = this.minLearningRate; } else { this.learningRate = this.learningRate * (1 + (1 - this.learningRateChange)); // *1.01 if (this.learningRate > this.maxLearningRate) this.learningRate = this.maxLearningRate; } */ // third approach used by sharky nn // By default It starts with ni = 0,9, and after each epoch ni is changed by: 0,99977 ^ N // where N is number of points, and ^ is power. // ni = ni * 0,99977 ^ N // this one drops the learning rate too fast // this.learningRate = this.learningRate * Math.pow(learningRateChange, this.getTrainingSet().size()); // if (this.learningRate > this.maxLearningRate) // this.learningRate = this.maxLearningRate; // // if (this.learningRate < this.minLearningRate) // this.learningRate = this.minLearningRate; // System.out.println("Iteration: "+currentIteration + " Learning rate: "+ this.learningRate); // one more approach suggested at https://sourceforge.net/tracker/?func=detail&atid=1107579&aid=3130561&group_id=238532 // if (this.totalNetworkError >= this.previousEpochError) { // // If going wrong way, drop to minimum learning and work our way back up. // // This way we accelerate as we improve. // learningRate=minLearningRate; // } else { // this.learningRate = this.learningRate * (1 + (1 - this.learningRateChange)); // *1.01 // // if (this.learningRate > this.maxLearningRate) // this.learningRate = this.maxLearningRate; // // } } protected void adjustMomentum() { double errorChange = this.previousEpochError - this.totalNetworkError; this.momentum = this.momentum + (errorChange*momentumChange); if (this.momentum > this.maxMomentum) this.momentum = this.maxMomentum; if (this.momentum < this.minMomentum) this.momentum = this.minMomentum; // one more approach suggested at https://sourceforge.net/tracker/?func=detail&atid=1107579&aid=3130561&group_id=238532 // Probably want to drop momentum to minimum value. // if (this.totalNetworkError >= this.previousEpochError) { // momentum = momentum * momentumChange; // if (momentum < minMomentum) momentum = minMomentum; // } else { // momentum = momentum * (1 + (1 - momentumChange)); // *1.01 // if (momentum > maxMomentum) momentum = maxMomentum; // } } @Override public void doLearningEpoch(TrainingSet trainingSet) { super.doLearningEpoch(trainingSet); if (currentIteration > 0) { if (useDynamicLearningRate) adjustLearningRate(); if (useDynamicMomentum) adjustMomentum(); } } public double getLearningRateChange() { return learningRateChange; } public void setLearningRateChange(double learningRateChange) { this.learningRateChange = learningRateChange; } public double getMaxLearningRate() { return maxLearningRate; } public void setMaxLearningRate(double maxLearningRate) { this.maxLearningRate = maxLearningRate; } public double getMaxMomentum() { return maxMomentum; } public void setMaxMomentum(double maxMomentum) { this.maxMomentum = maxMomentum; } public double getMinLearningRate() { return minLearningRate; } public void setMinLearningRate(double minLearningRate) { this.minLearningRate = minLearningRate; } public double getMinMomentum() { return minMomentum; } public void setMinMomentum(double minMomentum) { this.minMomentum = minMomentum; } public double getMomentumChange() { return momentumChange; } public void setMomentumChange(double momentumChange) { this.momentumChange = momentumChange; } public boolean getUseDynamicLearningRate() { return useDynamicLearningRate; } public void setUseDynamicLearningRate(boolean useDynamicLearningRate) { this.useDynamicLearningRate = useDynamicLearningRate; } public boolean getUseDynamicMomentum() { return useDynamicMomentum; } public void setUseDynamicMomentum(boolean useDynamicMomentum) { this.useDynamicMomentum = useDynamicMomentum; } }