/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 *
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.networks.training.lma;

import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.matrices.decomposition.LUDecomposition;
import org.encog.mathutil.matrices.hessian.ComputeHessian;
import org.encog.mathutil.matrices.hessian.HessianCR;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.validate.ValidateNetwork;

/**
 * Trains a neural network using the Levenberg-Marquardt algorithm (LMA). This
 * training technique is based on the mathematical technique of the same name.
 *
 * The LMA interpolates between the Gauss-Newton algorithm (GNA) and the
 * method of gradient descent (similar to what is used by backpropagation).
 * The lambda parameter determines the degree to which GNA and gradient
 * descent are used. A lower lambda results in heavier use of GNA, whereas a
 * higher lambda results in heavier use of gradient descent.
 *
 * Each iteration starts with a low lambda, which is increased whenever the
 * resulting adjustment fails to improve the neural network. At some point the
 * lambda is high enough that the training method reverts entirely to gradient
 * descent.
 *
 * This allows the neural network to be trained effectively in cases where GNA
 * provides the optimal training time, while retaining the ability to fall
 * back to the more primitive gradient descent method.
 *
 * LMA finds only a local minimum, not a global minimum.
 *
 * References:
 * http://www.heatonresearch.com/wiki/LMA
 * http://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm
 * http://en.wikipedia.org/wiki/Finite_difference_method
 * http://crsouza.blogspot.com/2009/11/neural-network-learning-by-levenberg_18.html
 * http://mathworld.wolfram.com/FiniteDifference.html
 * http://www-alg.ist.hokudai.ac.jp/~jan/alpha.pdf
 * http://www.inference.phy.cam.ac.uk/mackay/Bayes_FAQ.html
 */
public class LevenbergMarquardtTraining extends BasicTraining implements
        MultiThreadable {

    /**
     * The amount to scale the lambda by.
     */
    public static final double SCALE_LAMBDA = 10.0;

    /**
     * The maximum value for lambda.
     */
    public static final double LAMBDA_MAX = 1e25;
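    /*
     * Lambda schedule, sketched numerically from the code below: lambda
     * starts at 0.1 (set in the constructor). A step that fails to reduce
     * the error multiplies lambda by SCALE_LAMBDA (0.1 -> 1.0 -> 10.0 ->
     * ...), while a successful step divides it by SCALE_LAMBDA. Once lambda
     * reaches LAMBDA_MAX the current iteration stops searching.
     */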
    /**
     * Utility class to compute the Hessian.
     */
    private ComputeHessian hessian;

    /**
     * The network that is to be trained.
     */
    private final BasicNetwork network;

    /**
     * The training set that we are using to train.
     */
    private final MLDataSet indexableTraining;

    /**
     * The training set length.
     */
    private final int trainingLength;

    /**
     * The number of weights that we are dealing with.
     */
    private final int weightCount;

    /**
     * The neural network weights and bias values.
     */
    private double[] weights;

    /**
     * The lambda, or damping factor. This is increased until a desirable
     * adjustment is found.
     */
    private double lambda;

    /**
     * The diagonal of the Hessian.
     */
    private final double[] diagonal;

    /**
     * The amount to change the weights by.
     */
    private double[] deltas;

    /**
     * The training elements.
     */
    private final MLDataPair pair;

    /**
     * Is the init complete?
     */
    private boolean initComplete;

    /**
     * Construct the LMA object.
     *
     * @param network
     *            The network to train. Must have a single output neuron.
     * @param training
     *            The training data to use. Must be indexable.
     */
    public LevenbergMarquardtTraining(final BasicNetwork network,
            final MLDataSet training) {
        this(network, training, new HessianCR());
    }

    /**
     * Construct the LMA object.
     *
     * @param network
     *            The network to train. Must have a single output neuron.
     * @param training
     *            The training data to use. Must be indexable.
     * @param h
     *            Utility class to compute the Hessian.
     */
    public LevenbergMarquardtTraining(final BasicNetwork network,
            final MLDataSet training, final ComputeHessian h) {
        super(TrainingImplementationType.Iterative);
        ValidateNetwork.validateMethodToData(network, training);
        setTraining(training);

        this.indexableTraining = getTraining();
        this.network = network;
        this.trainingLength = (int) this.indexableTraining.getRecordCount();
        this.weightCount = this.network.getStructure().calculateSize();
        this.lambda = 0.1;
        this.deltas = new double[this.weightCount];
        this.diagonal = new double[this.weightCount];

        final BasicMLData input = new BasicMLData(
                this.indexableTraining.getInputSize());
        final BasicMLData ideal = new BasicMLData(
                this.indexableTraining.getIdealSize());
        this.pair = new BasicMLDataPair(input, ideal);

        this.hessian = h;
    }

    /**
     * Save the diagonal of the Hessian, so that it can be restored between
     * lambda adjustments.
     */
    private void saveDiagonal() {
        final double[][] h = this.hessian.getHessian();
        for (int i = 0; i < this.weightCount; i++) {
            this.diagonal[i] = h[i][i];
        }
    }

    /**
     * @return False, this training method does not support continuation.
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * @return The trained network.
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return The SSE error with the current weights.
     */
    private double calculateError() {
        final ErrorCalculation result = new ErrorCalculation();

        for (int i = 0; i < this.trainingLength; i++) {
            this.indexableTraining.getRecord(i, this.pair);
            final MLData actual = this.network.compute(this.pair.getInput());
            result.updateError(actual.getData(), this.pair.getIdeal()
                    .getData(), this.pair.getSignificance());
        }

        return result.calculateESS();
    }

    /**
     * Apply the current lambda to the diagonal of the Hessian.
     */
    private void applyLambda() {
        final double[][] h = this.hessian.getHessian();
        for (int i = 0; i < this.weightCount; i++) {
            h[i][i] = this.diagonal[i] + this.lambda;
        }
    }
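    /*
     * The core Levenberg-Marquardt step performed by iteration() below, as
     * a sketch (H is the approximate Hessian, g the gradient vector, I the
     * identity matrix, and deltas the weight adjustment solved for):
     *
     *     (H + lambda * I) * deltas = g
     *
     * A small lambda makes the step approach Gauss-Newton; a large lambda
     * shrinks the step toward plain gradient descent.
     */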
    /**
     * Perform one iteration.
     */
    @Override
    public void iteration() {

        if (!this.initComplete) {
            this.hessian.init(this.network, getTraining());
            this.initComplete = true;
        }

        LUDecomposition decomposition = null;
        preIteration();

        // Compute the Hessian and gradients at the current weights.
        this.hessian.clear();
        this.weights = NetworkCODEC.networkToArray(this.network);
        this.hessian.compute();
        double currentError = this.hessian.getSSE();
        saveDiagonal();

        final double startingError = currentError;
        boolean done = false;
        boolean nonSingular;

        while (!done) {
            // Damp the Hessian diagonal with the current lambda and attempt
            // to solve for the weight deltas.
            applyLambda();
            decomposition = new LUDecomposition(
                    this.hessian.getHessianMatrix());
            nonSingular = decomposition.isNonsingular();

            if (nonSingular) {
                this.deltas = decomposition.Solve(this.hessian.getGradients());
                updateWeights();
                currentError = calculateError();
            }

            if (!nonSingular || currentError >= startingError) {
                // The step failed (singular matrix or no improvement), so
                // increase lambda, moving toward gradient descent.
                this.lambda *= LevenbergMarquardtTraining.SCALE_LAMBDA;
                if (this.lambda > LevenbergMarquardtTraining.LAMBDA_MAX) {
                    this.lambda = LevenbergMarquardtTraining.LAMBDA_MAX;
                    done = true;
                }
            } else {
                // The step improved the error, so decrease lambda, moving
                // toward Gauss-Newton.
                this.lambda /= LevenbergMarquardtTraining.SCALE_LAMBDA;
                done = true;
            }
        }

        setError(currentError);

        postIteration();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void resume(final TrainingContinuation state) {
    }

    /**
     * Update the weights in the neural network.
     */
    public void updateWeights() {
        final double[] w = this.weights.clone();

        for (int i = 0; i < w.length; i++) {
            w[i] += this.deltas[i];
        }

        NetworkCODEC.arrayToNetwork(w, this.network);
    }

    /**
     * @return The Hessian calculation method used.
     */
    public ComputeHessian getHessian() {
        return this.hessian;
    }

    /**
     * @return The thread count used by the Hessian calculation, or 1 if the
     *         Hessian object does not support multi-threading.
     */
    @Override
    public int getThreadCount() {
        if (this.hessian instanceof MultiThreadable) {
            return ((MultiThreadable) this.hessian).getThreadCount();
        } else {
            return 1;
        }
    }

    /**
     * Set the thread count for the Hessian calculation, if it supports
     * multi-threading.
     *
     * @param numThreads
     *            The number of threads, or 0 to let the framework decide.
     */
    @Override
    public void setThreadCount(final int numThreads) {
        if (this.hessian instanceof MultiThreadable) {
            ((MultiThreadable) this.hessian).setThreadCount(numThreads);
        } else if (numThreads != 1 && numThreads != 0) {
            throw new TrainingError("The Hessian object in use ("
                    + this.hessian.getClass().toString()
                    + ") does not support multi-threaded mode.");
        }
    }
}
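/*
 * Example usage (a minimal sketch; "network" and "trainingSet" are assumed
 * to be a BasicNetwork and an MLDataSet constructed elsewhere):
 *
 *     LevenbergMarquardtTraining train =
 *             new LevenbergMarquardtTraining(network, trainingSet);
 *     do {
 *         train.iteration();
 *     } while (train.getError() > 0.01);
 */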