/*
 * Encog(tm) Core v2.5 - Java Version
 * http://www.heatonresearch.com/encog/
 * http://code.google.com/p/encog-java/
 * Copyright 2008-2010 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.engine.network.train.prop;

import org.encog.engine.EncogEngineError;
import org.encog.engine.concurrency.DetermineWorkload;
import org.encog.engine.concurrency.EngineConcurrency;
import org.encog.engine.concurrency.TaskGroup;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.gradient.GradientWorkerCPU;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.IntRange;

/**
 * Train a flat network using multithreading, with optional GPU support.
 *
 * The training data must be indexable; it is broken into ranges, one for each
 * thread to process.
 *
 * At the end of each iteration the training results from each thread are
 * aggregated back into the neural network.
 */
public abstract class TrainFlatNetworkProp implements TrainFlatNetwork {

	/**
	 * The number of threads to use.
	 */
	private int numThreads;

	/**
	 * The gradients.
	 */
	protected double[] gradients;

	/**
	 * The last gradients, from the last training iteration.
	 */
	private double[] lastGradient;

	/**
	 * The network to train.
	 */
	protected final FlatNetwork network;

	/**
	 * The training data.
	 */
	private final EngineDataSet training;

	/**
	 * The network in indexable form.
	 */
	private final EngineIndexableSet indexable;

	/**
	 * The workers.
	 */
	private FlatGradientWorker[] workers;

	/**
	 * The total error, summed across workers and used to compute the average.
	 */
	private double totalError;

	/**
	 * The current error is the average error over all of the threads.
	 */
	protected double currentError;

	/**
	 * Reported exception from the threads.
	 */
	private Throwable reportedException;

	/**
	 * The iteration.
	 */
	private int iteration;

	/**
	 * Construct a multithreaded trainer for a flat network.
	 *
	 * @param network
	 *            The network to train.
	 * @param training
	 *            The training data to use.
	 */
	public TrainFlatNetworkProp(final FlatNetwork network,
			final EngineDataSet training) {

		if (!(training instanceof EngineIndexableSet)) {
			throw new EncogEngineError(
					"Training data must be Indexable for this training type.");
		}

		this.training = training;
		this.network = network;

		this.gradients = new double[this.network.getWeights().length];
		this.lastGradient = new double[this.network.getWeights().length];

		this.indexable = (EngineIndexableSet) training;
		this.numThreads = 0;
		this.reportedException = null;
	}

	/**
	 * Calculate the gradients.
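	 *
	 * Each worker computes gradients over its assigned range of the training
	 * data and reports them through
	 * {@link #report(double[], double, Throwable)}, where they are summed;
	 * the current error is then the average of the worker errors.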
	 */
	public void calculateGradients() {
		if (this.workers == null) {
			init();
		}

		this.workers[0].getNetwork().clearContext();
		this.totalError = 0;

		if (this.workers.length > 1) {

			final TaskGroup group = EngineConcurrency.getInstance()
					.createTaskGroup();

			for (final FlatGradientWorker worker : this.workers) {
				EngineConcurrency.getInstance().processTask(worker, group);
			}

			group.waitForComplete();
		} else {
			this.workers[0].run();
		}

		this.currentError = this.totalError / this.workers.length;
	}

	/**
	 * Copy the contexts to keep them consistent with multithreaded training.
	 */
	private void copyContexts() {
		// copy the contexts (layer output) from each worker to the next
		for (int i = 0; i < (this.workers.length - 1); i++) {
			final double[] src = this.workers[i].getNetwork().getLayerOutput();
			final double[] dst = this.workers[i + 1].getNetwork()
					.getLayerOutput();
			EngineArray.arrayCopy(src, dst);
		}

		// copy the contexts from the final worker to the real network
		EngineArray.arrayCopy(this.workers[this.workers.length - 1]
				.getNetwork().getLayerOutput(), this.network.getLayerOutput());
	}

	/**
	 * {@inheritDoc}
	 */
	public void finishTraining() {
		// nothing to do
	}

	/**
	 * {@inheritDoc}
	 */
	public double getError() {
		return this.currentError;
	}

	/**
	 * @return The gradients from the last iteration.
	 */
	public double[] getLastGradient() {
		return this.lastGradient;
	}

	/**
	 * {@inheritDoc}
	 */
	public FlatNetwork getNetwork() {
		return this.network;
	}

	/**
	 * {@inheritDoc}
	 */
	public int getNumThreads() {
		return this.numThreads;
	}

	/**
	 * {@inheritDoc}
	 */
	public EngineDataSet getTraining() {
		return this.training;
	}

	/**
	 * Initialize the workers, one per thread, each with its own clone of the
	 * network and its own view of the training data.
	 */
	private void init() {

		final DetermineWorkload determine = new DetermineWorkload(
				this.numThreads, (int) this.indexable.getRecordCount());

		this.workers = new FlatGradientWorker[determine.getThreadCount()];

		int index = 0;

		// handle CPU
		for (final IntRange r : determine.calculateWorkers()) {
			this.workers[index++] = new GradientWorkerCPU(
					this.network.clone(), this,
					this.indexable.openAdditional(), r.getLow(), r.getHigh());
		}
	}

	/**
	 * {@inheritDoc}
	 */
	public void iteration() {
		this.iteration++;

		calculateGradients();

		if (this.network.isLimited()) {
			learnLimited();
		} else {
			learn();
		}

		for (final FlatGradientWorker worker : this.workers) {
			EngineArray.arrayCopy(this.network.getWeights(), 0,
					worker.getWeights(), 0, this.network.getWeights().length);
		}

		copyContexts();

		if (this.reportedException != null) {
			throw (new EncogEngineError(this.reportedException));
		}
	}

	/**
	 * Apply and learn.
	 */
	protected void learn() {
		final double[] weights = this.network.getWeights();
		for (int i = 0; i < this.gradients.length; i++) {
			weights[i] += updateWeight(this.gradients, this.lastGradient, i);
			this.gradients[i] = 0;
		}
	}

	/**
	 * Apply and learn. This is the same as learn, but it checks each weight
	 * against the limit threshold; weights whose magnitude falls below it are
	 * zeroed out. Having two methods allows the regular learn method, which
	 * is what is usually used, to be as fast as possible.
	 */
	protected void learnLimited() {
		final double limit = this.network.getConnectionLimit();
		final double[] weights = this.network.getWeights();
		for (int i = 0; i < this.gradients.length; i++) {
			// compare the weight's magnitude to the connection limit;
			// a connection that has fallen below it is treated as removed
			if (Math.abs(weights[i]) < limit) {
				weights[i] = 0;
			} else {
				weights[i] += updateWeight(this.gradients,
						this.lastGradient, i);
			}
			this.gradients[i] = 0;
		}
	}

	/**
	 * Called by the worker threads to report the progress at each step.
	 *
	 * @param gradients
	 *            The gradients from that worker.
	 * @param error
	 *            The error for that worker.
	 * @param ex
	 *            The exception.
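	 *            If non-null, the gradients and error from that worker are
	 *            discarded and the exception is rethrown, wrapped in an
	 *            EncogEngineError, at the end of the current iteration.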
	 */
	public void report(final double[] gradients, final double error,
			final Throwable ex) {
		synchronized (this) {
			if (ex == null) {

				for (int i = 0; i < gradients.length; i++) {
					this.gradients[i] += gradients[i];
				}
				this.totalError += error;
			} else {
				this.reportedException = ex;
			}
		}
	}

	/**
	 * {@inheritDoc}
	 */
	public void setNumThreads(final int numThreads) {
		this.numThreads = numThreads;
	}

	/**
	 * Update a weight. The means by which weights are updated varies
	 * depending on the training method.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @return The update value.
	 */
	public abstract double updateWeight(double[] gradients,
			double[] lastGradient, int index);

	/**
	 * Perform the specified number of training iterations. This is a basic
	 * implementation that just calls iteration the specified number of times.
	 * However, some training methods, particularly with the GPU, benefit
	 * greatly from being called with a count greater than 1.
	 *
	 * @param count
	 *            The number of training iterations.
	 */
	public void iteration(final int count) {
		for (int i = 0; i < count; i++) {
			iteration();
		}
	}

	/**
	 * {@inheritDoc}
	 */
	public int getIteration() {
		return this.iteration;
	}

	/**
	 * {@inheritDoc}
	 */
	public void setIteration(final int iteration) {
		this.iteration = iteration;
	}
}
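/*
 * Usage sketch. TrainFlatNetworkProp is abstract; a concrete subclass must
 * supply updateWeight(). The subclass name TrainFlatNetworkResilient and the
 * setup of "net" and "data" below are assumptions based on Encog 2.5
 * conventions, not part of this file:
 *
 *   FlatNetwork net = ...;      // a configured flat network
 *   EngineDataSet data = ...;   // must also implement EngineIndexableSet
 *
 *   TrainFlatNetworkProp train = new TrainFlatNetworkResilient(net, data);
 *   train.setNumThreads(0);     // 0 = let DetermineWorkload pick a count
 *
 *   for (int epoch = 0; epoch < 100; epoch++) {
 *       train.iteration();
 *       System.out.println("Epoch " + epoch + " error: " + train.getError());
 *   }
 *   train.finishTraining();
 */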