/** * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neuroph.nnet.learning; import org.neuroph.core.learning.TrainingData; import java.io.Serializable; import java.util.HashMap; import org.neuroph.core.Connection; import org.neuroph.core.Layer; import org.neuroph.core.NeuralNetwork; import org.neuroph.core.Neuron; import org.neuroph.core.Weight; import org.neuroph.core.learning.SupervisedLearning; import org.neuroph.core.learning.TrainingSet; /** * LMS learning rule for neural networks. * * @author Zoran Sevarac <sevarac@gmail.com> */ public class LMS extends SupervisedLearning implements Serializable { /** * The class fingerprint that is set to indicate serialization * compatibility with a previous version of the class. */ private static final long serialVersionUID = 1L; /** * Setting to determine if learning (weights update) is in batch mode * False by default. */ private boolean batchMode = false; /** * Training data buffer is used to store various values during the training. * This value determines the buffer capacity. */ protected int trainingDataBufferSize = 2; /** * Creates new LMS learning rule */ public LMS() { super(); } /** * Calculates and updates sum of squared errors for single pattern, and updates total sum of squared pattern errors * * @param patternError * single pattern error vector */ // see: http://www.vni.com/products/imsl/documentation/CNL06/stat/NetHelp/default.htm?turl=multilayerfeedforwardneuralnetworks.htm protected void updatePatternError(double[] patternErrorVector) { this.patternErrorSqrSum = 0; for (double error : patternErrorVector) { this.patternErrorSqrSum += (error * error) * 0.5; } this.totalPatternErrorSqrSum += this.patternErrorSqrSum; } @Override protected void updateTotalNetworkError() { this.totalNetworkError = totalPatternErrorSqrSum /(this.getTrainingSet().size()); } /** * This method implements weight update procedure for the whole network for * this learning rule * * @param patternError * single pattern error vector */ @Override protected void updateNetworkWeights(double[] patternError) { int i = 0; for (Neuron neuron : neuralNetwork.getOutputNeurons()) { neuron.setError(patternError[i]); this.updateNeuronWeights(neuron); i++; } } /** * This method implements weights update procedure for the single neuron * * @param neuron * neuron to update weights */ protected void updateNeuronWeights(Neuron neuron) { // get the error for specified neuron, double neuronError = neuron.getError(); // tanh can be used to minimise the impact of big error values, which can cause network instability // suggested at https://sourceforge.net/tracker/?func=detail&atid=1107579&aid=3130561&group_id=238532 // double neuronError = Math.tanh(neuron.getError()); // iterate through all neuron's input connections for (Connection connection : neuron.getInputConnections()) { // get the input from current connection double input = connection.getInput(); // calculate the weight change double deltaWeight = this.learningRate * neuronError * input; // update the weight change this.applyWeightChange(connection.getWeight(), deltaWeight); } } /** * Returns true if learning is performed in batch mode, false otherwise * @return true if learning is performed in batch mode, false otherwise */ public boolean isBatchMode() { return batchMode; } /** * Sets batch mode on/off (true/false) * @param batchMode batch mode setting */ public void setBatchMode(boolean batchMode) { this.batchMode = batchMode; } /** * This method does one learning epoch (one pass through training set) * and after that does weight update if learning is in batch mode * @param trainingSet */ @Override public void doLearningEpoch(TrainingSet trainingSet) { super.doLearningEpoch(trainingSet); if (this.batchMode == true) { // if learning is performed in batch mode, also apply accumulated weight changes from this epoch batchModeWeightsUpdate(); } } /** * This methods first checks to see if learning is performed in online or batch mode. * If learning is in online mode weight change is applied immediately. * If learning is in batch mode all weight changes during one epoch are summed * in trainingData buffer, and that sum is applied after each epoch. * @param weight weight that should be changed * @param deltaWeight weight change */ protected void applyWeightChange(Weight weight, double deltaWeight) { if (this.batchMode == false) { // if not in batch mode apply the weight change immediately weight.inc(deltaWeight); } else { // accumulate Weight change if its in batch mode double deltaWeightSum = weight.getTrainingData().get(TrainingData.DELTA_WEIGHT_SUM) + deltaWeight; weight.getTrainingData().set(TrainingData.DELTA_WEIGHT_SUM, deltaWeightSum); } } /** * This method updates network weights in batch mode - use accumulated weights change stored in trainingData * buffers to update all weights in network. It is executed after each epoch if learning is in batch mode. */ protected void batchModeWeightsUpdate() { // iterate layers from output to input for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) { Layer layer = neuralNetwork.getLayers().get(i); // iterate neurons at each layer for (Neuron neuron : layer.getNeurons()) { // iterate connections/weights for each neuron for (Connection connection : neuron.getInputConnections()) { // for each connection weight apply accumulated weight change Weight weight = connection.getWeight(); // get deltaWeightSum double deltaWeightSum = weight.getTrainingData().get(TrainingData.DELTA_WEIGHT_SUM); // apply the deltaWeightSum weight.inc(deltaWeightSum); // reset the deltaWeightSum to prepare it for next epoch weight.getTrainingData().set(TrainingData.DELTA_WEIGHT_SUM, 0); } } } } @Override public void setNeuralNetwork(NeuralNetwork neuralNetwork) { super.setNeuralNetwork(neuralNetwork); this.initTrainingDataBuffer(); } /** * This method initializes training data buffers in all network weights. * It can be overridden to create bigger training data buffer for each weight. */ protected void initTrainingDataBuffer() { for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) { Layer layer = neuralNetwork.getLayers().get(i); for (Neuron neuron : layer.getNeurons()) { for (Connection connection : neuron.getInputConnections()) { Weight weight = connection.getWeight(); weight.initTrainingDataBuffer(this.trainingDataBufferSize); } } } } public int getTrainingDataBufferSize() { return trainingDataBufferSize; } final public void setTrainingDataBufferSize(int trainingDataBufferSize) { this.trainingDataBufferSize = trainingDataBufferSize; } /** * * @param patternError * @deprecated */ @Override protected void updateTotalNetworkError(double[] patternError) { double sqrErrorSum = 0; for (double error : patternError) { sqrErrorSum += (error * error); } this.totalNetworkError += sqrErrorSum / (2 * patternError.length); } }