/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neuroph.core.learning;
import java.io.Serializable;
import java.util.Iterator;
// TODO: random pattern order
/**
* Base class for all supervised learning algorithms.
* It extends IterativeLearning, and provides general supervised learning principles.
*
* @author Zoran Sevarac <sevarac@gmail.com>
*/
public abstract class SupervisedLearning extends IterativeLearning implements
        Serializable {

    /**
     * The class fingerprint that is set to indicate serialization
     * compatibility with a previous version of the class
     */
    private static final long serialVersionUID = 3L;

    /**
     * Total network error
     */
    protected transient double totalNetworkError;

    /**
     * Sum of squared errors of each output neuron for one pattern 0.5*sum((et-e)^2)
     */
    protected transient double patternErrorSqrSum;

    /**
     * Total sum of all pattern errors
     */
    protected transient double totalPatternErrorSqrSum;

    /**
     * Total network error in previous epoch
     */
    protected transient double previousEpochError;

    /**
     * Max allowed network error (condition to stop learning)
     */
    protected double maxError = 0.01d;

    /**
     * Stopping condition: training stops if total network error change is
     * smaller than or equal to minErrorChange for
     * minErrorChangeIterationsLimit number of iterations.
     * <p>
     * Deliberately not transient: a transient field is restored as 0.0 by
     * deserialization (the initializer is not re-run), which would silently
     * replace both this default and any configured value.
     */
    private double minErrorChange = Double.POSITIVE_INFINITY;

    /**
     * Stopping condition: training stops if total network error change is
     * smaller than or equal to minErrorChange for
     * minErrorChangeIterationsLimit number of iterations.
     * <p>
     * Deliberately not transient: a transient field is restored as 0 by
     * deserialization, which would make {@link #errorChangeStalled()} report
     * a stall on the first epoch whose error change is within minErrorChange.
     */
    private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;

    /**
     * Count of consecutive iterations where error change is smaller than or
     * equal to minErrorChange. Per-run state, so it stays transient.
     */
    private transient int minErrorChangeIterationsCount;

    /**
     * Creates new supervised learning rule
     */
    public SupervisedLearning() {
        super();
    }

    /**
     * Trains network for the specified training set, using the specified
     * max error as the stopping condition.
     *
     * @param trainingSet training set to learn
     * @param maxError max allowed network error (stored in this rule before learning starts)
     */
    public void learn(TrainingSet trainingSet, double maxError) {
        this.maxError = maxError;
        this.learn(trainingSet);
    }

    /**
     * Trains network for the specified training set, using the specified
     * max error and maximum number of iterations.
     *
     * @param trainingSet training set to learn
     * @param maxError max allowed network error (stored in this rule before learning starts)
     * @param maxIterations maximum number of learning iterations (passed to setMaxIterations)
     */
    public void learn(TrainingSet trainingSet, double maxError, int maxIterations) {
        this.maxError = maxError;
        this.setMaxIterations(maxIterations);
        this.learn(trainingSet);
    }

    /**
     * Resets the state of this learning rule: calls the superclass reset and
     * then clears the stalled-iterations counter, the total network error and
     * the previous epoch error.
     */
    @Override
    protected void reset() {
        super.reset();
        this.minErrorChangeIterationsCount = 0;
        this.totalNetworkError = 0d;
        this.previousEpochError = 0d;
    }

    /**
     * This method implements basic logic for one learning epoch for the
     * supervised learning algorithms. Epoch is the one pass through the
     * training set. This method iterates through the training set
     * and trains network for each element. After the pass it updates the
     * total network error and stops learning if
     * {@link #hasReachedStopCondition()} returns true.
     *
     * @param trainingSet
     *            training set for training network
     */
    @Override
    public void doLearningEpoch(TrainingSet trainingSet) {
        // remember the last epoch's error before the accumulators are cleared
        this.previousEpochError = this.totalNetworkError;
        this.totalNetworkError = 0d;
        this.totalPatternErrorSqrSum = 0d;

        // the pass is abandoned early if learning gets stopped in the meantime
        Iterator<TrainingElement> iterator = trainingSet.iterator();
        while (iterator.hasNext() && !isStopped()) {
            SupervisedTrainingElement supervisedTrainingElement = (SupervisedTrainingElement) iterator.next();
            this.learnPattern(supervisedTrainingElement);
        }

        // note: runs even when the loop above ended early because of isStopped()
        this.updateTotalNetworkError();

        // stopping condition lives in a separate method, hasReachedStopCondition(),
        // so it can be overridden / customized in subclasses
        if (hasReachedStopCondition()) {
            stopLearning();
        }
    }

    /**
     * Returns true if stop condition has been reached, false otherwise.
     * The default condition is: total network error is below max error, or
     * the error change has stalled (see {@link #errorChangeStalled()}).
     * Override this method in derived classes to implement custom stop criteria.
     *
     * @return true if stop condition is reached, false otherwise
     */
    protected boolean hasReachedStopCondition() {
        // TODO: should this also check whether the error-change condition is in use at all?
        // Or make that condition always active, just with a very small default
        // value for minErrorChange (e.g. Double.MIN_VALUE)?
        // Note: errorChangeStalled() mutates its counter and, because of the
        // short-circuit ||, is only evaluated when the error is not below maxError.
        return (this.totalNetworkError < this.maxError) || this.errorChangeStalled();
    }

    /**
     * Returns true if absolute error change is sufficiently small
     * ({@code <= minErrorChange}) for minErrorChangeIterationsLimit number of
     * consecutive calls. Each call updates the internal counter: it is
     * incremented when the change is small enough and reset to zero otherwise.
     *
     * @return true if absolute error change is stalled (error change is sufficiently small for some number of iterations)
     */
    protected boolean errorChangeStalled() {
        double absErrorChange = Math.abs(previousEpochError - totalNetworkError);

        if (absErrorChange <= this.minErrorChange) {
            this.minErrorChangeIterationsCount++;

            if (this.minErrorChangeIterationsCount >= this.minErrorChangeIterationsLimit) {
                return true;
            }
        } else {
            // change was large enough - the streak of stalled iterations is broken
            this.minErrorChangeIterationsCount = 0;
        }

        return false;
    }

    /**
     * Trains network with the pattern from the specified training element:
     * feeds the input to the network, calculates the output, computes the
     * pattern error against the desired output, and then updates the pattern
     * error sum and the network weights.
     *
     * @param trainingElement
     *            supervised training element which contains input and desired
     *            output
     */
    protected void learnPattern(SupervisedTrainingElement trainingElement) {
        double[] input = trainingElement.getInput();
        this.neuralNetwork.setInput(input);
        this.neuralNetwork.calculate();
        double[] output = this.neuralNetwork.getOutput();
        double[] desiredOutput = trainingElement.getDesiredOutput();
        double[] patternError = this.getPatternError(output, desiredOutput);
        this.updatePatternError(patternError);
        this.updateNetworkWeights(patternError);
    }

    /**
     * Calculates the network error for the current pattern - difference between
     * desired and actual output, element by element (desired - actual).
     * The result has the same length as the actual output.
     *
     * @param output
     *            actual network output
     * @param desiredOutput
     *            desired network output; must have at least as many elements as output
     * @return pattern error
     */
    protected double[] getPatternError(double[] output, double[] desiredOutput) {
        double[] patternError = new double[output.length];

        for (int i = 0; i < output.length; i++) {
            patternError[i] = desiredOutput[i] - output[i];
        }

        return patternError;
    }

    /**
     * Sets allowed network error, which indicates when to stop learning
     *
     * @param maxError
     *            network error
     */
    public void setMaxError(double maxError) {
        this.maxError = maxError;
    }

    /**
     * Returns learning error tolerance - the value of total network error to stop learning.
     *
     * @return learning error tolerance
     */
    public double getMaxError() {
        return maxError;
    }

    /**
     * Returns total network error in current learning epoch
     *
     * @return total network error in current learning epoch
     */
    public synchronized double getTotalNetworkError() {
        return totalNetworkError;
    }

    /**
     * Returns total network error in previous learning epoch
     *
     * @return total network error in previous learning epoch
     */
    public double getPreviousEpochError() {
        return previousEpochError;
    }

    /**
     * Returns min error change stopping criteria
     *
     * @return min error change stopping criteria
     */
    public double getMinErrorChange() {
        return minErrorChange;
    }

    /**
     * Sets min error change stopping criteria
     *
     * @param minErrorChange value for min error change stopping criteria
     */
    public void setMinErrorChange(double minErrorChange) {
        this.minErrorChange = minErrorChange;
    }

    /**
     * Returns number of iterations for min error change stopping criteria
     *
     * @return number of iterations for min error change stopping criteria
     */
    public int getMinErrorChangeIterationsLimit() {
        return minErrorChangeIterationsLimit;
    }

    /**
     * Sets number of iterations for min error change stopping criteria
     *
     * @param minErrorChangeIterationsLimit number of iterations for min error change stopping criteria
     */
    public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) {
        this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
    }

    /**
     * Returns the current iterations count for min error change stopping criteria
     *
     * @return current iterations count for min error change stopping criteria
     */
    public int getMinErrorChangeIterationsCount() {
        return minErrorChangeIterationsCount;
    }

    /**
     * Subclasses update total network error for each training pattern with this
     * method. Error update formula is learning rule specific.
     *
     * @param patternError pattern error vector
     * @deprecated this class does not call this overload; its epoch logic uses
     *             {@link #updatePatternError(double[])} for each pattern and
     *             {@link #updateTotalNetworkError()} once per epoch instead.
     */
    protected abstract void updateTotalNetworkError(double[] patternError);

    /**
     * This method should implement the weights update procedure
     *
     * @param patternError
     *            pattern error vector
     */
    protected abstract void updateNetworkWeights(double[] patternError);

    /**
     * This method should calculate sum sqr for single pattern error and
     * update the error sum for all patterns.
     *
     * @param patternErrorVector Error vector for pattern
     */
    protected abstract void updatePatternError(double[] patternErrorVector);

    /**
     * This method should calculate the MSE or other type of total network error used.
     * It is called once at the end of every learning epoch.
     */
    protected abstract void updateTotalNetworkError();
}