/* * Encog(tm) Core v2.5 - Java Version * http://www.heatonresearch.com/encog/ * http://code.google.com/p/encog-java/ * Copyright 2008-2010 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.engine.network.train.gradient; import org.encog.engine.data.BasicEngineData; import org.encog.engine.data.EngineData; import org.encog.engine.data.EngineIndexableSet; import org.encog.engine.network.activation.ActivationFunction; import org.encog.engine.network.flat.FlatNetwork; import org.encog.engine.network.train.prop.TrainFlatNetworkProp; import org.encog.engine.util.EngineArray; import org.encog.engine.util.ErrorCalculation; import org.encog.engine.util.Stopwatch; /** * Worker class for the mulithreaded training of flat networks. */ public class GradientWorkerCPU implements FlatGradientWorker { /** * The network to train. */ private final FlatNetwork network; /** * The error calculation method. */ private final ErrorCalculation errorCalculation = new ErrorCalculation(); /** * The actual values from the neural network. */ private final double[] actual; /** * The deltas for each layer. */ private final double[] layerDelta; /** * The neuron counts, per layer. */ private final int[] layerCounts; /** * The feed counts, per layer. */ private final int[] layerFeedCounts; /** * The layer indexes. */ private final int[] layerIndex; /** * The index to each layer's weights and thresholds. */ private final int[] weightIndex; /** * The output from each layer. */ private final double[] layerOutput; /** * The gradients. */ private final double[] gradients; /** * The weights and thresholds. */ private final double[] weights; /** * The pair to use for training. */ private final EngineData pair; /** * The training data. */ private final EngineIndexableSet training; /** * The high end of the training data. */ private final int low; /** * The low end of the training. */ private final int high; /** * The owner. */ private final TrainFlatNetworkProp owner; /** * The elapsed time. */ private long elapsedTime; /** * The stopwatch, to evaluate performance. */ private final Stopwatch stopwatch; /** * Construct a gradient worker. * * @param network * The network to train. * @param owner * The owner that is doing the training. * @param training * The training data. * @param low * The low index to use in the training data. * @param high * The high index to use in the training data. */ public GradientWorkerCPU(final FlatNetwork network, final TrainFlatNetworkProp owner, final EngineIndexableSet training, final int low, final int high) { this.network = network; this.training = training; this.low = low; this.high = high; this.owner = owner; this.stopwatch = new Stopwatch(); this.layerDelta = new double[network.getLayerOutput().length]; this.gradients = new double[network.getWeights().length]; this.actual = new double[network.getOutputCount()]; this.weights = network.getWeights(); this.layerIndex = network.getLayerIndex(); this.layerCounts = network.getLayerCounts(); this.weightIndex = network.getWeightIndex(); this.layerOutput = network.getLayerOutput(); this.layerFeedCounts = network.getLayerFeedCounts(); this.pair = BasicEngineData.createPair(network.getInputCount(), network .getOutputCount()); } /** * @return Elapsed time for the last iteration. */ public long getElapsedTime() { return this.elapsedTime; } /** * @return The network training. */ @Override public FlatNetwork getNetwork() { return this.network; } /** * @return The weights for this network. */ public double[] getWeights() { return this.weights; } /** * Process one training set element. * * @param input * The network input. * @param ideal * The ideal values. */ private void process(final double[] input, final double[] ideal) { this.network.compute(input, this.actual); this.errorCalculation.updateError(this.actual, ideal); for (int i = 0; i < this.actual.length; i++) { this.layerDelta[i] = this.network.getActivationFunctions()[0] .derivativeFunction(this.actual[i]) * (ideal[i] - this.actual[i]); } for (int i = this.network.getBeginTraining(); i < this.network .getEndTraining(); i++) { processLevel(i); } } /** * Process one level. * * @param currentLevel * The level. */ private void processLevel(final int currentLevel) { final int fromLayerIndex = this.layerIndex[currentLevel + 1]; final int toLayerIndex = this.layerIndex[currentLevel]; final int fromLayerSize = this.layerCounts[currentLevel + 1]; final int toLayerSize = this.layerFeedCounts[currentLevel]; final int index = this.weightIndex[currentLevel]; final ActivationFunction activation = this.network .getActivationFunctions()[currentLevel + 1]; // handle weights int yi = fromLayerIndex; for (int y = 0; y < fromLayerSize; y++) { final double output = this.layerOutput[yi]; double sum = 0; int xi = toLayerIndex; int wi = index + y; for (int x = 0; x < toLayerSize; x++) { this.gradients[wi] += output * this.layerDelta[xi]; sum += this.weights[wi] * this.layerDelta[xi]; wi += fromLayerSize; xi++; } this.layerDelta[yi] = sum * activation.derivativeFunction(this.layerOutput[yi]); yi++; } } /** * Perform the gradient calculation for the specified index range. */ public void run() { try { this.stopwatch.reset(); this.stopwatch.start(); this.errorCalculation.reset(); for (int i = this.low; i <= this.high; i++) { this.training.getRecord(i, this.pair); process(this.pair.getInputArray(), this.pair.getIdealArray()); } final double error = this.errorCalculation.calculate(); this.owner.report(this.gradients, error, null); EngineArray.fill(this.gradients, 0); this.stopwatch.stop(); this.elapsedTime = this.stopwatch.getElapsedTicks(); } catch (final Throwable ex) { this.owner.report(null, 0, ex); } } }