/** * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neuroph.nnet.flat; import java.io.Serializable; import org.encog.engine.EncogEngineError; import org.encog.engine.data.EngineIndexableSet; import org.encog.engine.network.flat.FlatNetwork; import org.encog.engine.network.train.TrainFlatNetwork; import org.encog.engine.network.train.prop.TrainFlatNetworkBackPropagation; import org.encog.engine.network.train.prop.TrainFlatNetworkManhattan; import org.encog.engine.network.train.prop.TrainFlatNetworkResilient; import org.neuroph.core.NeuralNetwork; import org.neuroph.core.learning.SupervisedLearning; import org.neuroph.core.learning.TrainingSet; /** * Learning rule for flat networks. The learning type is specified by the the learningType * property. Currently the flat networks can be trained in the following three ways. * * Classic momentum based propagation. The learning rate and momentum are specified by using the * setLearningRate and setMomentum on the FlatNetworkLearning class. * * Resilient Propagation. No parameters are needed for this learning method. RPROP is the best * general purpose training method. * * Manhattan update rule. Not at all a good general purpose learning method, only useful in * some situations. A learning rate must be specified by using the setLearningRate method on * the FlatNetworkLearning class. * * * @author Jeff Heaton (http://www.jeffheaton.com) * */ public class FlatNetworkLearning extends SupervisedLearning implements Serializable { @Override protected void updatePatternError(double[] patternError) { throw new UnsupportedOperationException("Not supported yet."); } @Override protected void updateTotalNetworkError() { throw new UnsupportedOperationException("Not supported yet."); } /** * The object serial id. */ private static final long serialVersionUID = 1L; /** * The flat network to train. */ private FlatNetwork flat; /** * The last training set used. This property is used to tell if the training set has * changed since the last iteration. */ private transient EngineIndexableSet lastTrainingSet; /** * The last type of learning used. This property is used to tell if the training set has * changed since the last iteration. */ private FlatLearningType lastLearningType; /** * The flat trainer that is in use. */ private transient TrainFlatNetwork training; /** * The learning rate. This value is used for backprop and Manhattan. */ private double learningRate; /** * The momentum. This value is used for backpropagation training. */ private double momentum; /** * The number of threads to use. The default, zero, specifies to use the processor count * to determine an optimal number of threads. */ private int numThreads; /** * The type of flat learning to be used. */ private FlatLearningType learningType; /** * Constructor, create a flat network learning rule from the specified flat network. * @param flat The flat network. */ public FlatNetworkLearning(FlatNetwork flat) { init(flat); } /** * Constructor, create a flat network learning rule from a Neuroph network. This neuroph * network must have already had the flat network plugin installed. * @param network The neuroph network to use. */ public FlatNetworkLearning(NeuralNetwork network) { FlatNetworkPlugin plugin = (FlatNetworkPlugin) network .getPlugin(FlatNetworkPlugin.PLUGIN_NAME); FlatNetwork flat = plugin.getFlatNetwork(); if (this.flat == null) throw new EncogEngineError( "This learning rule only works with a network that has a FlatNetworkPlugin attached."); init(flat); } /** * Internal method used to setup the learning rule. * @param flat The network that will be trained. */ private void init(FlatNetwork flat) { this.flat = flat; this.learningType = FlatLearningType.ResilientPropagation; this.learningRate = 0.7; this.momentum = 0.3; } /** * This method is not used. * @param patternError Not used. */ @Override protected void updateNetworkWeights(double[] patternError) { throw new EncogEngineError( "Method (updateNetworkWeights) is unimplemented and should not have been called."); } /** * This method is not used. * @param patternError Not used. */ @Override protected void updateTotalNetworkError(double[] patternError) { throw new EncogEngineError( "Method (updateTotalNetworkError) is unimplemented and should not have been called."); } /** * Perform one learning epoch. * @param trainingSet The training set to use. */ @Override public void doLearningEpoch(TrainingSet trainingSet) { this.previousEpochError = this.totalNetworkError; // have we changed learning types, or training sets? If so, create a whole new trainer. if (this.lastLearningType != this.learningType || this.lastTrainingSet != trainingSet) { this.lastTrainingSet = trainingSet; switch( this.learningType ) { case ResilientPropagation: this.training = new TrainFlatNetworkResilient(this.flat, this.lastTrainingSet); break; case ManhattanUpdateRule: this.training = new TrainFlatNetworkManhattan(this.flat, this.lastTrainingSet, this.learningRate); break; case BackPropagation: this.training = new TrainFlatNetworkBackPropagation(this.flat, this.lastTrainingSet, this.getLearningRate(), this.getMomentum()); break; } this.training.setNumThreads(this.numThreads); // remember the last state of the learning type and trianing set so that we can // tell if this changes in the next iteration. this.lastLearningType = this.learningType; this.lastTrainingSet = trainingSet; } // perform a training iteration this.training.iteration(); this.totalNetworkError = this.training.getError(); // should Neuroph training stop if (hasReachedStopCondition()) { stopLearning(); } } /** * @return The learning rate. This value is used for Backprop and Manhattan training. */ public double getLearningRate() { return learningRate; } /** * Set the learning rate. This value is used for Backprop and Manhattan training. */ public void setLearningRate(double learningRate) { this.learningRate = learningRate; } /** * @return The momentum, this is is used for Backprop. */ public double getMomentum() { return momentum; } /** * Set the momentum, this is used for Backprop. * @param momentum */ public void setMomentum(double momentum) { this.momentum = momentum; } /** * @return The number of threads to use. Zero requests that the learning * rule choose a number of processors. */ public int getNumThreads() { return numThreads; } /** * Set the number of threads to use. * @param numThreads */ public void setNumThreads(int numThreads) { this.numThreads = numThreads; } /** * @return The type of learning to use. */ public FlatLearningType getLearningType() { return learningType; } /** * Set the type of learning to use. * @param learningType The type of learning to use. */ public void setLearningType(FlatLearningType learningType) { this.learningType = learningType; } }