/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.mathutil.matrices.hessian;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.flat.FlatNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;
/**
 * A threaded worker that is used to calculate the first derivatives of the
 * output of the neural network. These values are ultimately used to calculate
 * the Hessian.
 *
 * <p>Each worker handles the training-set rows in the range [low, high] for a
 * single output neuron; partial results (gradients, first derivatives and the
 * Hessian contribution) are accumulated locally and read back by the owner via
 * the getters. This class is not thread-safe for concurrent use of the same
 * instance; the intended usage is one worker per thread.</p>
 */
public class ChainRuleWorker implements EngineTask {

	/**
	 * The actual values from the neural network.
	 */
	private final double[] actual;

	/**
	 * The deltas for each layer.
	 */
	private final double[] layerDelta;

	/**
	 * The neuron counts, per layer.
	 */
	private final int[] layerCounts;

	/**
	 * The feed counts, per layer.
	 */
	private final int[] layerFeedCounts;

	/**
	 * The layer indexes.
	 */
	private final int[] layerIndex;

	/**
	 * The index to each layer's weights and thresholds.
	 */
	private final int[] weightIndex;

	/**
	 * The output from each layer.
	 */
	private final double[] layerOutput;

	/**
	 * The sums.
	 */
	private final double[] layerSums;

	/**
	 * The weights and thresholds.
	 */
	private final double[] weights;

	/**
	 * The flat network.
	 */
	private final FlatNetwork flat;

	/**
	 * The training data.
	 */
	private final MLDataSet training;

	/**
	 * The output neuron to calculate for.
	 */
	private int outputNeuron;

	/**
	 * The total first derivatives.
	 */
	private final double[] totDeriv;

	/**
	 * The gradients.
	 */
	private final double[] gradients;

	/**
	 * The error (sum of squared errors over the processed range).
	 */
	private double error;

	/**
	 * The low range.
	 */
	private final int low;

	/**
	 * The high range.
	 */
	private final int high;

	/**
	 * The pair to use for training.
	 */
	private final MLDataPair pair;

	/**
	 * The weight count.
	 */
	private final int weightCount;

	/**
	 * The hessian for this worker.
	 */
	private final double[][] hessian;

	/**
	 * Construct the chain rule worker.
	 *
	 * @param theNetwork The network to calculate a Hessian for.
	 * @param theTraining The training data.
	 * @param theLow The low range (inclusive index into the training set).
	 * @param theHigh The high range (inclusive index into the training set).
	 */
	public ChainRuleWorker(FlatNetwork theNetwork, MLDataSet theTraining, int theLow, int theHigh) {
		this.weightCount = theNetwork.getWeights().length;
		this.hessian = new double[this.weightCount][this.weightCount];

		this.training = theTraining;
		this.flat = theNetwork;

		this.layerDelta = new double[flat.getLayerOutput().length];
		this.actual = new double[flat.getOutputCount()];
		this.totDeriv = new double[weightCount];
		this.gradients = new double[weightCount];

		// Cache the flat network's internal arrays; these are live views, not
		// copies, so layerOutput/layerSums reflect the most recent compute().
		this.weights = flat.getWeights();
		this.layerIndex = flat.getLayerIndex();
		this.layerCounts = flat.getLayerCounts();
		this.weightIndex = flat.getWeightIndex();
		this.layerOutput = flat.getLayerOutput();
		this.layerSums = flat.getLayerSums();
		this.layerFeedCounts = flat.getLayerFeedCounts();

		this.low = theLow;
		this.high = theHigh;
		this.pair = BasicMLDataPair.createPair(flat.getInputCount(), flat
				.getOutputCount());
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public void run() {
		this.error = 0;
		EngineArray.fill(this.hessian, 0);
		EngineArray.fill(this.totDeriv, 0);
		EngineArray.fill(this.gradients, 0);

		// Scratch array holding d(output)/d(weight) for one training element.
		double[] derivative = new double[this.weightCount];

		// Loop over every training element in this worker's range (inclusive).
		for (int i = this.low; i <= this.high; i++) {
			this.training.getRecord(i, this.pair);
			EngineArray.fill(derivative, 0);
			process(outputNeuron, derivative, pair.getInputArray(), pair.getIdealArray());
		}
	}

	/**
	 * Process one training set element: run a forward pass, back-propagate the
	 * chain-rule derivatives for the selected output neuron, and accumulate
	 * the error, gradients, total derivatives and Hessian contribution.
	 *
	 * @param outputNeuron
	 *            The output neuron whose first derivatives are computed.
	 * @param derivative
	 *            Scratch array (length weightCount) receiving this element's
	 *            per-weight derivatives; must be zeroed by the caller.
	 * @param input
	 *            The network input.
	 * @param ideal
	 *            The ideal values.
	 */
	private void process(int outputNeuron, double[] derivative, final double[] input, final double[] ideal) {

		this.flat.compute(input, this.actual);

		double e = ideal[outputNeuron] - this.actual[outputNeuron];
		this.error += e * e;

		// Seed the deltas: only the selected output neuron contributes; the
		// others are zeroed. NOTE(review): activation index 0 is assumed to be
		// the output layer (FlatNetwork stores layers output-first) — confirm
		// against FlatNetwork's layout.
		for (int i = 0; i < this.actual.length; i++) {

			if (i == outputNeuron) {
				this.layerDelta[i] = this.flat.getActivationFunctions()[0]
						.derivativeFunction(this.layerSums[i],
								this.layerOutput[i]);
			} else {
				this.layerDelta[i] = 0;
			}
		}

		// Back-propagate the derivatives through the trainable layers.
		for (int i = this.flat.getBeginTraining(); i < this.flat.getEndTraining(); i++) {
			processLevel(i, derivative);
		}

		// Accumulate gradients and total first derivatives.
		// (weightCount == weights.length; use one bound consistently.)
		for (int j = 0; j < this.weightCount; j++) {
			this.gradients[j] += e * derivative[j];
			totDeriv[j] += derivative[j];
		}

		// Accumulate this element's outer product d*d^T into the Hessian.
		// The product is symmetric, so compute the upper triangle once and
		// mirror it — identical results at half the multiplications.
		for (int i = 0; i < this.weightCount; i++) {
			final double di = derivative[i];
			this.hessian[i][i] += di * di;
			for (int j = i + 1; j < this.weightCount; j++) {
				final double v = di * derivative[j];
				this.hessian[i][j] += v;
				this.hessian[j][i] += v;
			}
		}
	}

	/**
	 * Process one level (layer transition): accumulate the per-weight
	 * derivatives for the weights feeding the "to" layer and back-propagate
	 * the deltas into the "from" layer.
	 *
	 * @param currentLevel
	 *            The level; the transition is from layer currentLevel+1 into
	 *            layer currentLevel.
	 * @param derivative
	 *            The per-weight derivative accumulator for the current
	 *            training element.
	 */
	private void processLevel(final int currentLevel, double[] derivative) {
		final int fromLayerIndex = this.layerIndex[currentLevel + 1];
		final int toLayerIndex = this.layerIndex[currentLevel];
		final int fromLayerSize = this.layerCounts[currentLevel + 1];
		final int toLayerSize = this.layerFeedCounts[currentLevel];

		final int index = this.weightIndex[currentLevel];
		final ActivationFunction activation = this.flat
				.getActivationFunctions()[currentLevel + 1];

		// Handle weights: weights into a given "to" neuron are laid out with a
		// stride of fromLayerSize, hence wi starts at index+y and advances by
		// fromLayerSize per "to" neuron.
		int yi = fromLayerIndex;
		for (int y = 0; y < fromLayerSize; y++) {
			final double output = this.layerOutput[yi];
			double sum = 0;
			int xi = toLayerIndex;
			int wi = index + y;
			for (int x = 0; x < toLayerSize; x++) {
				derivative[wi] += output * this.layerDelta[xi];
				sum += this.weights[wi] * this.layerDelta[xi];
				wi += fromLayerSize;
				xi++;
			}

			// Chain rule: fold the weighted deltas through this neuron's
			// activation derivative to form the next layer's delta.
			this.layerDelta[yi] = sum
					* (activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]));

			yi++;
		}
	}

	/**
	 * @return the outputNeuron
	 */
	public int getOutputNeuron() {
		return outputNeuron;
	}

	/**
	 * @param outputNeuron the outputNeuron to set
	 */
	public void setOutputNeuron(int outputNeuron) {
		this.outputNeuron = outputNeuron;
	}

	/**
	 * @return The first derivatives, used to calculate the Hessian.
	 */
	public double[] getDerivative() {
		return this.totDeriv;
	}

	/**
	 * @return the gradients
	 */
	public double[] getGradients() {
		return gradients;
	}

	/**
	 * @return The SSE error.
	 */
	public double getError() {
		return this.error;
	}

	/**
	 * @return The flat network.
	 */
	public FlatNetwork getNetwork() {
		return this.flat;
	}

	/**
	 * @return the hessian
	 */
	public double[][] getHessian() {
		return hessian;
	}
}