/* * Encog(tm) Core v2.5 - Java Version * http://www.heatonresearch.com/encog/ * http://code.google.com/p/encog-java/ * Copyright 2008-2010 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.neural.networks.training.propagation; import org.encog.EncogError; import org.encog.engine.network.flat.FlatNetwork; import org.encog.engine.network.train.TrainFlatNetwork; import org.encog.engine.network.train.prop.OpenCLTrainingProfile; import org.encog.neural.data.NeuralDataSet; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.structure.FlatUpdateNeeded; import org.encog.neural.networks.training.BasicTraining; import org.encog.neural.networks.training.TrainingError; import org.encog.util.EncogValidate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Implements basic functionality that is needed by each of the propagation * methods. The specifics of each of the propagation methods is implemented * inside of the PropagationMethod interface implementors. * * @author jheaton * */ public abstract class Propagation extends BasicTraining { /** * The network. */ private final BasicNetwork network; /** * The current flat network we are using for training, or null for none. */ private FlatNetwork currentFlatNetwork; /** * The current flat trainer we are using, or null for none. */ private TrainFlatNetwork flatTraining; /** * The logging object. */ @SuppressWarnings("unused") private final Logger logger = LoggerFactory.getLogger(this.getClass()); /** * Construct a propagation object. * * @param network * The network. * @param training * The training set. */ public Propagation(final BasicNetwork network, final NeuralDataSet training) { super(); this.network = network; setTraining(training); } /** * @return True if this training can be continued. */ public boolean canContinue() { return false; } /** * Should be called after training has completed and the iteration method * will not be called any further. */ @Override public void finishTraining() { super.finishTraining(); this.network.getStructure().updateFlatNetwork(); this.flatTraining.finishTraining(); } /** * @return the currentFlatNetwork */ public FlatNetwork getCurrentFlatNetwork() { return this.currentFlatNetwork; } /** * @return the flatTraining */ public TrainFlatNetwork getFlatTraining() { return this.flatTraining; } /** * @return The network. */ public BasicNetwork getNetwork() { return this.network; } /** * @return The number of threads. */ public int getNumThreads() { return this.flatTraining.getNumThreads(); } /** * @return The OpenCL device to use, or null for the CPU. */ public OpenCLTrainingProfile getProfile() { return null; } /** * Determine if this specified training continuation object is valid for * this training method. * * @param state * The training continuation object to check. * @return True if the continuation object is valid. */ public boolean isValidResume(final TrainingContinuation state) { return false; } /** * Perform one training iteration. */ public void iteration() { try { this.network.getStructure().updateFlatNetwork(); preIteration(); this.flatTraining.iteration(); setError(this.flatTraining.getError()); this.network.getStructure().setFlatUpdate( FlatUpdateNeeded.Unflatten); postIteration(); if (this.logger.isInfoEnabled()) { this.logger.info("Training iteration done, error: " + getError()); } } catch (final ArrayIndexOutOfBoundsException ex) { EncogValidate.validateNetworkForTraining(this.network, getTraining()); throw new EncogError(ex); } } /** * Perform the specified number of training iterations. This can be more * efficient than single training iterations. This is particularly true if * you are training with a GPU. * * @param count * The number of training iterations. */ @Override public void iteration(final int count) { try { preIteration(); this.flatTraining.iteration(count); setIteration(this.flatTraining.getIteration()); setError(this.flatTraining.getError()); this.network.getStructure().setFlatUpdate( FlatUpdateNeeded.Unflatten); postIteration(); if (this.logger.isInfoEnabled()) { this.logger.info("Training iterations done, error: " + getError()); } } catch (final ArrayIndexOutOfBoundsException ex) { EncogValidate.validateNetworkForTraining(this.network, getTraining()); throw new EncogError(ex); } } /** * Pause the training to continue later. * * @return A training continuation object. */ public TrainingContinuation pause() { throw new TrainingError("This training type does not support pause."); } /** * Resume training. * * @param state * The training continuation object to use to continue. */ public void resume(final TrainingContinuation state) { throw new TrainingError("This training type does not support resume."); } /** * @param flatTraining * the flatTraining to set */ public void setFlatTraining(final TrainFlatNetwork flatTraining) { this.flatTraining = flatTraining; } /** * Set the number of threads. Specify zero to tell Encog to automatically * determine the best number of threads for the processor. If OpenCL is used * as the target device, then this value is not used. * * @param numThreads * The number of threads. */ public void setNumThreads(final int numThreads) { this.flatTraining.setNumThreads(numThreads); } }