/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.propagation.sgd;
import org.encog.Encog;
import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.CrossEntropyErrorFunction;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.propagation.sgd.update.AdamUpdate;
import org.encog.neural.networks.training.propagation.sgd.update.UpdateRule;
import org.encog.util.EngineArray;
/**
 * Stochastic gradient descent trainer for flat neural networks.  Gradients
 * are accumulated with {@link #process(MLDataPair)} and applied to the
 * weights by a pluggable {@link UpdateRule} (Adam by default).  Training
 * data is normally wrapped in a {@link BatchDataSet} so each iteration
 * operates on a mini-batch.
 *
 * <p>Not thread-safe: all methods mutate shared trainer state.</p>
 */
public class StochasticGradientDescent extends BasicTraining implements Momentum,
		LearningRate {

	/**
	 * The learning rate (step size used by the update rule).
	 */
	private double learningRate;

	/**
	 * The momentum.
	 */
	private double momentum;

	/**
	 * The accumulated gradients, one entry per network weight.
	 */
	private final double[] gradients;

	/**
	 * The deltas for each layer.
	 */
	private final double[] layerDelta;

	/**
	 * L1 regularization coefficient.
	 */
	private double l1;

	/**
	 * L2 regularization coefficient.
	 */
	private double l2;

	/**
	 * The update rule used to apply accumulated gradients to the weights.
	 */
	private UpdateRule updateRule = new AdamUpdate();

	/**
	 * The last delta values.  NOTE(review): not referenced by any method in
	 * this class; retained for binary compatibility.
	 */
	private double[] lastDelta;

	/**
	 * A flat neural network.
	 */
	private FlatNetwork flat;

	/**
	 * The error function used to compute the output-layer deltas.
	 */
	private ErrorFunction errorFunction = new CrossEntropyErrorFunction();

	/**
	 * Accumulates the error of every pair processed since the last update.
	 */
	private ErrorCalculation errorCalculation;

	/**
	 * Random number source, used by BatchDataSet for mini-batch sampling.
	 */
	private GenerateRandom rnd;

	/**
	 * The machine learning method (network) being trained.
	 */
	private MLMethod method;

	/**
	 * Construct the trainer with a default Mersenne-Twister random source.
	 *
	 * @param network  The network to train.
	 * @param training The training set.
	 */
	public StochasticGradientDescent(final ContainsFlat network,
			final MLDataSet training) {
		this(network, training, new MersenneTwisterGenerateRandom());
	}

	/**
	 * Construct the trainer.  If the training set is not already a
	 * {@link BatchDataSet} it is wrapped in one with a default mini-batch
	 * size of 25.
	 *
	 * @param network   The network to train.
	 * @param training  The training set.
	 * @param theRandom The random source used for mini-batch sampling.
	 */
	public StochasticGradientDescent(final ContainsFlat network,
			final MLDataSet training, final GenerateRandom theRandom) {
		super(TrainingImplementationType.Iterative);
		setTraining(training);
		this.rnd = theRandom;

		if (!(training instanceof BatchDataSet)) {
			setBatchSize(25);
		}

		this.method = network;
		this.flat = network.getFlat();
		this.layerDelta = new double[this.flat.getLayerOutput().length];
		this.gradients = new double[this.flat.getWeights().length];
		this.errorCalculation = new ErrorCalculation();
		this.learningRate = 0.001;
		this.momentum = 0.9;
	}

	/**
	 * Process a single training pair: forward pass, error accumulation and
	 * backpropagation of gradients.  Call {@link #update()} afterwards to
	 * apply the accumulated gradients to the weights.
	 *
	 * @param pair The training pair to process.
	 */
	public void process(final MLDataPair pair) {
		final double[] actual = new double[this.flat.getOutputCount()];

		this.flat.compute(pair.getInputArray(), actual);

		// Accumulate error across pairs.  (A previous version re-created
		// the error calculation here, which discarded everything but the
		// last pair's error and made resetError() a no-op.)
		this.errorCalculation.updateError(actual, pair.getIdealArray(),
				pair.getSignificance());

		// Calculate error for the output layer.  FlatNetwork stores its
		// layers output-first, so activation index 0 is the output layer.
		this.errorFunction.calculateError(
				this.flat.getActivationFunctions()[0], this.flat.getLayerSums(),
				this.flat.getLayerOutput(),
				pair.getIdeal().getData(), actual, this.layerDelta, 0,
				pair.getSignificance());

		// Apply regularization, if requested.
		if (this.l1 > Encog.DEFAULT_DOUBLE_EQUAL
				|| this.l2 > Encog.DEFAULT_DOUBLE_EQUAL) {
			final double[] lp = new double[2];
			calculateRegularizationPenalty(lp);
			// The penalty is identical for every output neuron; compute it
			// once instead of inside the loop.
			final double p = (lp[0] * this.l1) + (lp[1] * this.l2);
			for (int i = 0; i < actual.length; i++) {
				this.layerDelta[i] += p;
			}
		}

		// Propagate backwards (chain rule from calculus).
		for (int i = this.flat.getBeginTraining(); i < this.flat
				.getEndTraining(); i++) {
			processLevel(i);
		}
	}

	/**
	 * Apply the accumulated gradients to the weights using the configured
	 * update rule, record the accumulated error, then clear the gradients
	 * and error accumulator and advance to the next mini-batch.
	 */
	public void update() {
		// Lazily initialize the update rule on the very first iteration.
		if (getIteration() == 0) {
			this.updateRule.init(this);
		}

		preIteration();

		this.updateRule.update(this.gradients, this.flat.getWeights());
		setError(this.errorCalculation.calculate());

		postIteration();

		EngineArray.fill(this.gradients, 0);
		this.errorCalculation.reset();

		if (getTraining() instanceof BatchDataSet) {
			((BatchDataSet) getTraining()).advance();
		}
	}

	/**
	 * Reset the accumulated error.
	 */
	public void resetError() {
		this.errorCalculation.reset();
	}

	/**
	 * Backpropagate the deltas from one layer to the layer feeding it,
	 * accumulating the gradient of every weight between the two layers.
	 *
	 * @param currentLevel The layer level to process.
	 */
	private void processLevel(final int currentLevel) {
		final int fromLayerIndex = this.flat.getLayerIndex()[currentLevel + 1];
		final int toLayerIndex = this.flat.getLayerIndex()[currentLevel];
		final int fromLayerSize = this.flat.getLayerCounts()[currentLevel + 1];
		final int toLayerSize = this.flat.getLayerFeedCounts()[currentLevel];

		final int index = this.flat.getWeightIndex()[currentLevel];
		final ActivationFunction activation = this.flat
				.getActivationFunctions()[currentLevel];

		// Handle weights.  Array references are made method-local to avoid
		// one level of indirection in the hot loop.
		final double[] layerDelta = this.layerDelta;
		final double[] weights = this.flat.getWeights();
		final double[] gradients = this.gradients;
		final double[] layerOutput = this.flat.getLayerOutput();
		final double[] layerSums = this.flat.getLayerSums();

		int yi = fromLayerIndex;
		for (int y = 0; y < fromLayerSize; y++) {
			final double output = layerOutput[yi];
			double sum = 0;

			int wi = index + y;
			final int loopEnd = toLayerIndex + toLayerSize;
			for (int xi = toLayerIndex; xi < loopEnd; xi++, wi += fromLayerSize) {
				gradients[wi] += output * layerDelta[xi];
				sum += weights[wi] * layerDelta[xi];
			}
			layerDelta[yi] = sum
					* (activation.derivativeFunction(layerSums[yi], layerOutput[yi]));

			yi++;
		}
	}

	/**
	 * Perform one training iteration: process every pair in the current
	 * mini-batch, then apply the accumulated gradients.  All per-iteration
	 * bookkeeping (update-rule init, pre/post iteration events and batch
	 * advancement) is performed inside {@link #update()}; a previous
	 * version repeated those steps here, firing the iteration events twice
	 * and advancing the batch twice (skipping every other mini-batch).
	 */
	@Override
	public void iteration() {
		for (int i = 0; i < getTraining().size(); i++) {
			process(getTraining().get(i));
		}
		update();
	}

	/**
	 * @return False; this trainer does not support training continuation.
	 */
	@Override
	public boolean canContinue() {
		return false;
	}

	/**
	 * @return The learning rate.
	 */
	@Override
	public double getLearningRate() {
		return this.learningRate;
	}

	/**
	 * @return The momentum.
	 */
	@Override
	public double getMomentum() {
		return this.momentum;
	}

	/**
	 * @param state A training continuation object.
	 * @return Always false; resuming is not supported.
	 */
	public boolean isValidResume(final TrainingContinuation state) {
		return false;
	}

	/**
	 * Pause the training.
	 *
	 * @return Always null; pausing is not supported by this trainer.
	 */
	@Override
	public TrainingContinuation pause() {
		return null;
	}

	/**
	 * Resuming is not supported.
	 *
	 * @param state Ignored.
	 * @throws EncogError Always.
	 */
	@Override
	public void resume(final TrainingContinuation state) {
		throw new EncogError("Resume not currently supported.");
	}

	/**
	 * @return The method being trained.
	 */
	@Override
	public MLMethod getMethod() {
		return this.method;
	}

	/**
	 * Set the learning rate.
	 *
	 * @param rate The new learning rate.
	 */
	@Override
	public void setLearningRate(final double rate) {
		this.learningRate = rate;
	}

	/**
	 * Set the momentum.
	 *
	 * @param m The new momentum.
	 */
	@Override
	public void setMomentum(final double m) {
		this.momentum = m;
	}

	/**
	 * Widened-visibility hook so collaborators (e.g. update rules) can fire
	 * the pre-iteration event; delegates to the superclass.
	 */
	public void preIteration() {
		super.preIteration();
	}

	/**
	 * @return The mini-batch size, or 0 if the training set is not batched.
	 */
	public int getBatchSize() {
		if (getTraining() instanceof BatchDataSet) {
			return ((BatchDataSet) getTraining()).getBatchSize();
		} else {
			return 0;
		}
	}

	/**
	 * Set the mini-batch size, wrapping the training set in a
	 * {@link BatchDataSet} if it is not one already.
	 *
	 * @param theBatchSize The mini-batch size.
	 */
	public void setBatchSize(final int theBatchSize) {
		if (getTraining() instanceof BatchDataSet) {
			((BatchDataSet) getTraining()).setBatchSize(theBatchSize);
		} else {
			final BatchDataSet batchSet = new BatchDataSet(getTraining(), this.rnd);
			batchSet.setBatchSize(theBatchSize);
			setTraining(batchSet);
		}
	}

	/**
	 * @return The L1 regularization coefficient.
	 */
	public double getL1() {
		return this.l1;
	}

	/**
	 * @param l1 The L1 regularization coefficient.
	 */
	public void setL1(final double l1) {
		this.l1 = l1;
	}

	/**
	 * @return The L2 regularization coefficient.
	 */
	public double getL2() {
		return this.l2;
	}

	/**
	 * @param l2 The L2 regularization coefficient.
	 */
	public void setL2(final double l2) {
		this.l2 = l2;
	}

	/**
	 * Sum the regularization penalties of every layer into the accumulator.
	 *
	 * @param l Two-element accumulator: [0] receives the L1 (absolute-value)
	 *          sum, [1] the L2 (squared) sum.
	 */
	public void calculateRegularizationPenalty(final double[] l) {
		for (int i = 0; i < this.flat.getLayerCounts().length - 1; i++) {
			layerRegularizationPenalty(i, l);
		}
	}

	/**
	 * Accumulate the regularization penalty for the weights leaving one layer.
	 *
	 * @param fromLayer The source layer.
	 * @param l         Two-element accumulator: [0] L1 sum, [1] L2 sum.
	 */
	public void layerRegularizationPenalty(final int fromLayer, final double[] l) {
		final int fromCount = this.flat.getLayerTotalNeuronCount(fromLayer);
		final int toCount = this.flat.getLayerNeuronCount(fromLayer + 1);

		for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
			for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
				final double w = this.flat.getWeight(fromLayer, fromNeuron, toNeuron);
				l[0] += Math.abs(w);
				l[1] += w * w;
			}
		}
	}

	/**
	 * @return The flat network being trained.
	 */
	public FlatNetwork getFlat() {
		return this.flat;
	}

	/**
	 * @return The update rule in use.
	 */
	public UpdateRule getUpdateRule() {
		return this.updateRule;
	}

	/**
	 * Set the update rule used to apply gradients to the weights.
	 *
	 * @param updateRule The new update rule.
	 */
	public void setUpdateRule(final UpdateRule updateRule) {
		this.updateRule = updateRule;
	}
}