/*
 * Encog(tm) Core v2.5 - Java Version
 * http://www.heatonresearch.com/encog/
 * http://code.google.com/p/encog-java/
 *
 * Copyright 2008-2010 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.networks.training.lma;

import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.matrices.decomposition.LUDecomposition;
import org.encog.neural.data.Indexable;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.data.basic.BasicNeuralDataPair;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.TrainingError;

/**
 * Trains a neural network using the Levenberg-Marquardt algorithm (LMA). This
 * training technique is based on the mathematical technique of the same name.
 *
 * http://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm
 *
 * The LMA training technique has some important limitations that you should
 * be aware of before using it.
 *
 * Only neural networks that have a single output neuron can be used with this
 * training technique.
 *
 * The entire training set must be loaded into memory. Because of this, an
 * Indexable training set must be used.
 *
 * Despite these limitations, the LMA training technique can be a very
 * effective training method.
 *
 * References:
 * - http://www-alg.ist.hokudai.ac.jp/~jan/alpha.pdf
 * - http://www.inference.phy.cam.ac.uk/mackay/Bayes_FAQ.html
 */
public class LevenbergMarquardtTraining extends BasicTraining {

    /**
     * The amount to scale the lambda by.
     */
    public static final double SCALE_LAMBDA = 10.0;

    /**
     * The max amount for the LAMBDA.
     */
    public static final double LAMBDA_MAX = 1e25;

    /**
     * Return the sum of the diagonal.
     *
     * @param m
     *            The matrix to sum.
     * @return The trace of the matrix.
     */
    public static double trace(final double[][] m) {
        double result = 0.0;
        for (int i = 0; i < m.length; i++) {
            result += m[i][i];
        }
        return result;
    }

    /**
     * The network that is to be trained.
     */
    private final BasicNetwork network;

    /**
     * The training set that we are using to train.
     */
    private final Indexable indexableTraining;

    /**
     * The training set length.
     */
    private final int trainingLength;

    /**
     * The number of "parameters" in the LMA algorithm. The parameters are what
     * the LMA adjusts to achieve the desired outcome. For neural network
     * optimization, the parameters are the weights and bias values.
     */
    private final int parametersLength;

    /**
     * The neural network weights and bias values.
     */
    private double[] weights;

    /**
     * The "hessian" matrix, used by the LMA.
     */
    private final Matrix hessianMatrix;

    /**
     * The "hessian" matrix as a 2d array.
     */
    private final double[][] hessian;
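
    /*
     * A note on the math (an editorial sketch derived from the code below,
     * not part of the original comments): each LMA step solves the damped
     * normal equations
     *
     *     (H + (lambda + alpha) * I) * delta = g
     *
     * where H ~ beta * J'J approximates the Hessian, g = J'e is the gradient,
     * J is the Jacobian of the per-record errors with respect to the weights,
     * and e is the error vector. A small lambda makes the step behave like
     * Gauss-Newton; a large lambda makes it behave like gradient descent. The
     * alpha term comes from the Bayesian regularization penalty on the
     * weights.
     */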

    /**
     * The alpha is multiplied by the sum squared of the weights. This scales
     * the effect that the sum squared of the weights has.
     */
    private double alpha;

    /**
     * The beta is multiplied by the sum squared of the errors.
     */
    private double beta;

    /**
     * The lambda, or damping factor. This is increased until a desirable
     * adjustment is found.
     */
    private double lambda;

    /**
     * The calculated gradients.
     */
    private final double[] gradient;

    /**
     * The diagonal of the hessian.
     */
    private final double[] diagonal;

    /**
     * The amount to change the weights by.
     */
    private double[] deltas;

    /**
     * Gamma, used for Bayesian regularization.
     */
    private double gamma;

    /**
     * The training elements.
     */
    private final NeuralDataPair pair;

    /**
     * Should we use Bayesian regularization.
     */
    private boolean useBayesianRegularization;

    /**
     * Construct the LMA object.
     *
     * @param network
     *            The network to train. Must have a single output neuron.
     * @param training
     *            The training data to use. Must be indexable.
     */
    public LevenbergMarquardtTraining(final BasicNetwork network,
            final NeuralDataSet training) {

        if (!(training instanceof Indexable)) {
            throw new TrainingError(
                    "Levenberg Marquardt requires an indexable training set.");
        }

        final Layer outputLayer = network.getLayer(BasicNetwork.TAG_OUTPUT);

        if (outputLayer == null) {
            throw new TrainingError(
                    "Levenberg Marquardt requires an output layer.");
        }

        if (outputLayer.getNeuronCount() != 1) {
            throw new TrainingError(
                    "Levenberg Marquardt requires an output layer with a single neuron.");
        }

        setTraining(training);
        this.indexableTraining = (Indexable) getTraining();
        this.network = network;
        this.trainingLength = (int) this.indexableTraining.getRecordCount();
        this.parametersLength = this.network.getStructure().calculateSize();
        this.hessianMatrix = new Matrix(this.parametersLength,
                this.parametersLength);
        this.hessian = this.hessianMatrix.getData();
        this.alpha = 0.0;
        this.beta = 1.0;
        this.lambda = 0.1;
        this.deltas = new double[this.parametersLength];
        this.gradient = new double[this.parametersLength];
        this.diagonal = new double[this.parametersLength];

        final BasicNeuralData input = new BasicNeuralData(
                this.indexableTraining.getInputSize());
        final BasicNeuralData ideal = new BasicNeuralData(
                this.indexableTraining.getIdealSize());
        this.pair = new BasicNeuralDataPair(input, ideal);
    }

    /**
     * Calculate the Hessian matrix.
     *
     * @param jacobian
     *            The Jacobian matrix.
     * @param errors
     *            The errors.
     */
    public void calculateHessian(final double[][] jacobian,
            final double[] errors) {
        for (int i = 0; i < this.parametersLength; i++) {
            // Compute the Jacobian matrix errors (the gradient)
            double s = 0.0;
            for (int j = 0; j < this.trainingLength; j++) {
                s += jacobian[j][i] * errors[j];
            }
            this.gradient[i] = s;

            // Compute the quasi-Hessian matrix using the Jacobian (H = J'J)
            for (int j = 0; j < this.parametersLength; j++) {
                double c = 0.0;
                for (int k = 0; k < this.trainingLength; k++) {
                    c += jacobian[k][i] * jacobian[k][j];
                }
                this.hessian[i][j] = this.beta * c;
            }
        }

        for (int i = 0; i < this.parametersLength; i++) {
            this.diagonal[i] = this.hessian[i][i];
        }
    }

    /**
     * Calculate the sum squared of the weights.
     *
     * @return The sum squared of the weights, divided by 2.
     */
    private double calculateSumOfSquaredWeights() {
        double result = 0;

        for (final double weight : this.weights) {
            result += weight * weight;
        }

        return result / 2.0;
    }

    /**
     * @return The trained network.
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }
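
    /*
     * A note on Bayesian regularization (an editorial sketch; see the MacKay
     * FAQ linked from the class Javadoc): when enabled, the quantity
     * minimized is not the raw error but the regularized objective
     *
     *     beta * (sum of squared errors) + alpha * (sum of squared weights)
     *
     * After each iteration, gamma (the effective number of parameters) and
     * the alpha/beta hyperparameters are re-estimated from the trace of the
     * inverse Hessian; see the end of iteration() below.
     */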

    /**
     * @return True, if Bayesian regularization is used.
     */
    public boolean isUseBayesianRegularization() {
        return this.useBayesianRegularization;
    }

    /**
     * Perform one iteration.
     */
    public void iteration() {

        LUDecomposition decomposition = null;
        double trace = 0;

        preIteration();

        this.weights = NetworkCODEC.networkToArray(this.network);

        final ComputeJacobian j = new JacobianChainRule(this.network,
                this.indexableTraining);

        double sumOfSquaredErrors = j.calculate(this.weights);
        double sumOfSquaredWeights = calculateSumOfSquaredWeights();

        // this.setError(j.getError());
        calculateHessian(j.getJacobian(), j.getRowErrors());

        // Define the objective function: the Bayesian regularization
        // objective function.
        final double objective = this.beta * sumOfSquaredErrors + this.alpha
                * sumOfSquaredWeights;
        double current = objective + 1.0;

        // Start the main Levenberg-Marquardt method
        this.lambda /= LevenbergMarquardtTraining.SCALE_LAMBDA;

        // We'll try to find a direction with less error
        // (or where the objective function is smaller).
        while ((current >= objective)
                && (this.lambda < LevenbergMarquardtTraining.LAMBDA_MAX)) {

            this.lambda *= LevenbergMarquardtTraining.SCALE_LAMBDA;

            // Update the diagonal (Levenberg-Marquardt formula)
            for (int i = 0; i < this.parametersLength; i++) {
                this.hessian[i][i] = this.diagonal[i]
                        + (this.lambda + this.alpha);
            }

            // Decompose to solve the linear system
            decomposition = new LUDecomposition(this.hessianMatrix);

            // Check if the Hessian has become singular (non-invertible); if
            // so, increase the damping factor and try again.
            if (!decomposition.isNonsingular()) {
                continue;
            }

            // Solve using the LU decomposition
            this.deltas = decomposition.Solve(this.gradient);

            // Update the weights using the calculated deltas
            sumOfSquaredWeights = updateWeights();

            // Calculate the new error
            sumOfSquaredErrors = 0.0;
            for (int i = 0; i < this.trainingLength; i++) {
                this.indexableTraining.getRecord(i, this.pair);
                final NeuralData actual = this.network.compute(this.pair
                        .getInput());
                final double e = this.pair.getIdeal().getData(0)
                        - actual.getData(0);
                sumOfSquaredErrors += e * e;
            }
            sumOfSquaredErrors /= 2.0;

            // Update the objective function
            current = this.beta * sumOfSquaredErrors + this.alpha
                    * sumOfSquaredWeights;

            // If the objective function is bigger than before, the method
            // is tried again using a greater damping factor.
        }

        // If this iteration caused an error drop, then the next iteration
        // will use a smaller damping factor.
        this.lambda /= LevenbergMarquardtTraining.SCALE_LAMBDA;

        if (this.useBayesianRegularization && (decomposition != null)) {
            // Compute the trace of the inverse Hessian
            trace = LevenbergMarquardtTraining.trace(decomposition.inverse());

            // Poland's update formula:
            this.gamma = this.parametersLength - (this.alpha * trace);
            this.alpha = this.parametersLength
                    / (2.0 * sumOfSquaredWeights + trace);
            this.beta = Math.abs((this.trainingLength - this.gamma)
                    / (2.0 * sumOfSquaredErrors));
        }

        setError(sumOfSquaredErrors);

        postIteration();
    }

    /**
     * Set if Bayesian regularization should be used.
     *
     * @param useBayesianRegularization
     *            True to use Bayesian regularization.
     */
    public void setUseBayesianRegularization(
            final boolean useBayesianRegularization) {
        this.useBayesianRegularization = useBayesianRegularization;
    }

    /**
     * Update the weights.
     *
     * @return The sum squared of the weights, divided by 2.
     */
    public double updateWeights() {
        double result = 0;
        final double[] w = this.weights.clone();

        for (int i = 0; i < w.length; i++) {
            w[i] += this.deltas[i];
            result += w[i] * w[i];
        }

        NetworkCODEC.arrayToNetwork(w, this.network);

        return result / 2.0;
    }
}
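
/*
 * Example usage (a minimal, editorial sketch; the setup below is illustrative
 * and not part of this class). Any NeuralDataSet that implements Indexable,
 * such as BasicNeuralDataSet, should work:
 *
 *     BasicNetwork network = ...; // must have exactly one output neuron
 *     NeuralDataSet training = new BasicNeuralDataSet(inputArray, idealArray);
 *
 *     LevenbergMarquardtTraining train =
 *         new LevenbergMarquardtTraining(network, training);
 *     train.setUseBayesianRegularization(true); // optional
 *
 *     do {
 *         train.iteration();
 *     } while (train.getError() > 0.01);
 */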