/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.propagation;
import java.util.Random;
import org.encog.Encog;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;
/**
 * Worker class for the multithreaded training of flat networks.
 * <p>
 * A worker processes the training records in the inclusive index range
 * {@code [low, high]}, accumulates gradients for every weight of its network,
 * and hands the gradients and the batch error to its {@link GradientWorkerOwner}.
 */
public class GradientWorker implements EngineTask {

    /**
     * Used to generate randomness for dropout.
     */
    protected Random dropoutRandomSource = new Random();

    /**
     * The network to train.
     */
    private final FlatNetwork network;

    /**
     * The error calculation method.
     */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();

    /**
     * The actual values from the neural network.
     */
    private final double[] actual;

    /**
     * The deltas for each layer.
     */
    private final double[] layerDelta;

    /**
     * The neuron counts, per layer.
     */
    private final int[] layerCounts;

    /**
     * The feed counts, per layer.
     */
    private final int[] layerFeedCounts;

    /**
     * The layer indexes.
     */
    private final int[] layerIndex;

    /**
     * The index to each layer's weights and thresholds.
     */
    private final int[] weightIndex;

    /**
     * The output from each layer.
     */
    private final double[] layerOutput;

    /**
     * The sums.
     */
    private final double[] layerSums;

    /**
     * The gradients.
     */
    private final double[] gradients;

    /**
     * The weights and thresholds.
     */
    private final double[] weights;

    /**
     * The pair to use for training.
     */
    private final MLDataPair pair;

    /**
     * The training data.
     */
    private final MLDataSet training;

    /**
     * The low end (first index, inclusive) of the training data to process.
     */
    private final int low;

    /**
     * The high end (last index, inclusive) of the training data to process.
     */
    private final int high;

    /**
     * The owner.
     */
    private final GradientWorkerOwner owner;

    /**
     * Derivative add constant. Used to combat flat spot.
     */
    private final double[] flatSpot;

    /**
     * The error function to use.
     */
    private final ErrorFunction errorFunction;

    /**
     * Per-level dropout rates, as obtained from the network. A level whose
     * entry is missing or zero has no dropout applied.
     */
    private final double[] layerDropoutRates;

    /**
     * Construct a gradient worker.
     *
     * @param theNetwork
     *            The network to train.
     * @param theOwner
     *            The owner that is doing the training.
     * @param theTraining
     *            The training data.
     * @param theLow
     *            The low index to use in the training data.
     * @param theHigh
     *            The high index to use in the training data.
     * @param flatSpot The flatspot additions for each layer
     * @param ef Error function
     */
    public GradientWorker(final FlatNetwork theNetwork,
            final GradientWorkerOwner theOwner,
            final MLDataSet theTraining, final int theLow,
            final int theHigh, final double[] flatSpot,
            ErrorFunction ef) {
        this.network = theNetwork;
        this.training = theTraining;
        this.low = theLow;
        this.high = theHigh;
        this.owner = theOwner;
        this.flatSpot = flatSpot;
        this.errorFunction = ef;

        this.layerDelta = new double[network.getLayerOutput().length];
        this.gradients = new double[network.getWeights().length];
        this.actual = new double[network.getOutputCount()];

        // These arrays are references to the network's own storage, not copies.
        this.weights = network.getWeights();
        this.layerIndex = network.getLayerIndex();
        this.layerCounts = network.getLayerCounts();
        this.layerDropoutRates = network.getLayerDropoutRates();
        this.weightIndex = network.getWeightIndex();
        this.layerOutput = network.getLayerOutput();
        this.layerSums = network.getLayerSums();
        this.layerFeedCounts = network.getLayerFeedCounts();

        this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
                .getOutputCount());
    }

    /**
     * @return The network being processed.
     */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return The weights for this network.
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Process one training set element: run the forward pass, record the
     * error, compute the output-layer deltas (plus any regularization
     * penalty), then back-propagate through the trainable levels,
     * accumulating into the gradient array.
     *
     * @param pair the training data information
     */
    public void process(final MLDataPair pair) {
        this.network.compute(pair.getInputArray(), this.actual);

        this.errorCalculation.updateError(this.actual, pair.getIdealArray(),
                pair.getSignificance());

        // Calculate error for the output layer.
        this.errorFunction.calculateError(
                this.network.getActivationFunctions()[0], this.layerSums,
                this.layerOutput, pair.getIdeal().getData(), this.actual,
                this.layerDelta, this.flatSpot[0], pair.getSignificance());

        // Apply regularization, if requested. L1 and L2 can each be enabled on
        // their own, so both coefficients must be tested.
        final double l1 = this.owner.getL1();
        final double l2 = this.owner.getL2();
        if (l1 > Encog.DEFAULT_DOUBLE_EQUAL || l2 > Encog.DEFAULT_DOUBLE_EQUAL) {
            final double[] lp = new double[2];
            calculateRegularizationPenalty(lp);
            // The same penalty is added to every output delta; compute it once.
            final double penalty = (lp[0] * l1) + (lp[1] * l2);
            for (int i = 0; i < this.actual.length; i++) {
                this.layerDelta[i] += penalty;
            }
        }

        // Propagate backwards (chain rule from calculus).
        for (int i = this.network.getBeginTraining(); i < this.network
                .getEndTraining(); i++) {
            processLevel(i);
        }
    }

    /**
     * Process one level: accumulate the gradients for the weights feeding
     * level {@code currentLevel} from level {@code currentLevel + 1}, and
     * compute the deltas of level {@code currentLevel + 1}.
     * <p>
     * When the level has a non-zero dropout rate, each neuron of level
     * {@code currentLevel + 1} is dropped with that probability. A dropped
     * neuron contributes no gradients and receives a delta of zero.
     *
     * @param currentLevel
     *            The level.
     */
    private void processLevel(final int currentLevel) {
        final int fromLayerIndex = this.layerIndex[currentLevel + 1];
        final int toLayerIndex = this.layerIndex[currentLevel];
        final int fromLayerSize = this.layerCounts[currentLevel + 1];
        final int toLayerSize = this.layerFeedCounts[currentLevel];

        // Tolerate a missing (null) or short dropout-rate array: no dropout.
        double dropoutRate = 0;
        if (this.layerDropoutRates != null
                && this.layerDropoutRates.length > currentLevel
                && this.layerDropoutRates[currentLevel] != 0) {
            dropoutRate = this.layerDropoutRates[currentLevel];
        }

        final int index = this.weightIndex[currentLevel];
        final ActivationFunction activation = this.network
                .getActivationFunctions()[currentLevel];
        final double currentFlatSpot = this.flatSpot[currentLevel + 1];

        // handle weights
        // array references are made method local to avoid one indirection
        final double[] layerDelta = this.layerDelta;
        final double[] weights = this.weights;
        final double[] gradients = this.gradients;
        final double[] layerOutput = this.layerOutput;
        final double[] layerSums = this.layerSums;

        final int loopEnd = toLayerIndex + toLayerSize;
        int yi = fromLayerIndex;
        for (int y = 0; y < fromLayerSize; y++) {
            final double output = layerOutput[yi];
            double sum = 0;

            // Weights for one source neuron are spaced fromLayerSize apart.
            int wi = index + y;
            if (dropoutRate == 0 || dropoutRandomSource.nextDouble() > dropoutRate) {
                for (int xi = toLayerIndex; xi < loopEnd; xi++, wi += fromLayerSize) {
                    gradients[wi] += output * layerDelta[xi];
                    sum += weights[wi] * layerDelta[xi];
                }
                layerDelta[yi] = sum
                        * (activation.derivativeFunction(layerSums[yi], layerOutput[yi])
                        + currentFlatSpot);
            } else {
                // Dropped neuron: nothing propagates through it.
                layerDelta[yi] = 0;
            }
            yi++;
        }
    }

    /**
     * Perform the gradient calculation for the specified index range.
     * <p>
     * The accumulated gradients and the error over the range are reported to
     * the owner, after which the gradient array is cleared. Any
     * {@link Throwable} raised during processing is reported to the owner
     * instead of being thrown.
     */
    public final void run() {
        try {
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; i++) {
                this.training.getRecord(i, this.pair);
                process(pair);
            }
            final double error = this.errorCalculation.calculate();
            this.owner.report(this.gradients, error, null);
            EngineArray.fill(this.gradients, 0);
        } catch (final Throwable ex) {
            this.owner.report(null, 0, ex);
        }
    }

    /**
     * Compute the gradients for a single training record and report them to
     * the owner, with an error value of 0, then clear the gradient array.
     * Unlike {@link #run()}, this neither resets the error calculation nor
     * catches exceptions.
     *
     * @param index The index of the training record to process.
     */
    public final void run(int index) {
        this.training.getRecord(index, this.pair);
        process(pair);
        this.owner.report(this.gradients, 0, null);
        EngineArray.fill(this.gradients, 0);
    }

    /**
     * @return The error calculation object updated by {@link #process}.
     */
    public ErrorCalculation getErrorCalculation() {
        return errorCalculation;
    }

    /**
     * @return the gradients
     */
    public double[] getGradients() {
        return gradients;
    }

    /**
     * Accumulate the regularization penalty terms over every layer-to-layer
     * weight group of the network.
     *
     * @param l Output array of length at least 2. The sum of absolute weights
     *          is added to {@code l[0]} (L1 term) and the sum of squared
     *          weights to {@code l[1]} (L2 term). Existing contents are added
     *          to, not reset.
     */
    public void calculateRegularizationPenalty(double[] l) {
        for (int i = 0; i < network.getLayerCounts().length - 1; i++) {
            layerRegularizationPenalty(i, l);
        }
    }

    /**
     * Accumulate the regularization penalty terms for the weights connecting
     * layer {@code fromLayer} to layer {@code fromLayer + 1}.
     *
     * @param fromLayer The source layer of the weight group.
     * @param l Output array. The sum of absolute weights is added to
     *          {@code l[0]} and the sum of squared weights to {@code l[1]}.
     */
    public void layerRegularizationPenalty(final int fromLayer, final double[] l) {
        final int fromCount = network.getLayerTotalNeuronCount(fromLayer);
        final int toCount = network.getLayerNeuronCount(fromLayer + 1);

        for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
            for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
                double w = this.network.getWeight(fromLayer, fromNeuron, toNeuron);
                l[0] += Math.abs(w);
                l[1] += w * w;
            }
        }
    }
}