/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 *
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.freeform.training;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;

import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.freeform.FreeformConnection;
import org.encog.neural.freeform.FreeformNetwork;
import org.encog.neural.freeform.FreeformNeuron;
import org.encog.neural.freeform.task.ConnectionTask;

/**
 * Provides the basic propagation functions used by the freeform trainers.
 */
public abstract class FreeformPropagationTraining extends BasicTraining
        implements Serializable {

    /**
     * The serial ID.
     */
    private static final long serialVersionUID = 1L;

    /**
     * The constant to use to fix the flat spot problem.
     */
    public static final double FLAT_SPOT_CONST = 0.1;

    /**
     * The network that we are training.
     */
    private final FreeformNetwork network;

    /**
     * The training set to use.
     */
    private final MLDataSet training;

    /**
     * The number of iterations.
     */
    private int iterationCount;

    /**
     * The error at the beginning of the last iteration.
     */
    private double error;

    /**
     * The neurons that have been visited.
     */
    private final Set<FreeformNeuron> visited = new HashSet<FreeformNeuron>();

    /**
     * Are we fixing the flat spot problem? (default = true)
     */
    private boolean fixFlatSpot = true;

    /**
     * The batch size. Specify 1 for pure online training. Specify 0 for pure
     * batch training (complete training set in one batch). Otherwise specify
     * the batch size for batch training.
     */
    private int batchSize = 0;

    /**
     * Don't use this constructor, it is for serialization only.
     */
    public FreeformPropagationTraining() {
        super(TrainingImplementationType.Iterative);
        this.network = null;
        this.training = null;
    }

    /**
     * Construct the trainer.
     * @param theNetwork The network to train.
     * @param theTraining The training data.
     */
    public FreeformPropagationTraining(final FreeformNetwork theNetwork,
            final MLDataSet theTraining) {
        super(TrainingImplementationType.Iterative);
        this.network = theNetwork;
        this.training = theTraining;
    }

    /**
     * Calculate the gradient for a neuron.
     * @param toNeuron The neuron to calculate for.
     */
    private void calculateNeuronGradient(final FreeformNeuron toNeuron) {

        // Only calculate if the neuron has an input summation; the output
        // neurons have already been handled, so this must be a hidden neuron.
        if (toNeuron.getInputSummation() != null) {

            // Calculate the layer deltas for the neurons that feed toNeuron,
            // and the gradients for all inbound connections to toNeuron.
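            //
            // Standard backpropagation relations used below (temp-training
            // slot 0 holds a neuron's layer delta and a connection's
            // accumulated gradient):
            //
            //   gradient(from -> to) = activation(from) * delta(to)
            //   delta(from) = f'(sum(from))
            //       * SUM over outbound c of [ weight(from -> c) * delta(c) ]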
            for (final FreeformConnection connection : toNeuron
                    .getInputSummation().list()) {

                // Calculate the gradient for this inbound connection.
                final double gradient = connection.getSource().getActivation()
                        * toNeuron.getTempTraining(0);
                connection.addTempTraining(0, gradient);

                // Calculate the next layer delta.
                final FreeformNeuron fromNeuron = connection.getSource();
                double sum = 0;
                for (final FreeformConnection toConnection : fromNeuron
                        .getOutputs()) {
                    sum += toConnection.getTarget().getTempTraining(0)
                            * toConnection.getWeight();
                }
                final double neuronOutput = fromNeuron.getActivation();
                final double neuronSum = fromNeuron.getSum();

                // Note: the derivative is taken from toNeuron's activation
                // function, which assumes the feeding neuron is activated
                // the same way.
                double deriv = toNeuron.getInputSummation()
                        .getActivationFunction()
                        .derivativeFunction(neuronSum, neuronOutput);

                if (this.fixFlatSpot
                        && (toNeuron.getInputSummation()
                                .getActivationFunction()
                                    instanceof ActivationSigmoid)) {
                    deriv += FreeformPropagationTraining.FLAT_SPOT_CONST;
                }

                final double layerDelta = sum * deriv;
                fromNeuron.setTempTraining(0, layerDelta);
            }

            // Recurse to the next level.
            for (final FreeformConnection connection : toNeuron
                    .getInputSummation().list()) {
                final FreeformNeuron fromNeuron = connection.getSource();
                calculateNeuronGradient(fromNeuron);
            }
        }
    }

    /**
     * Calculate the output delta for a neuron, given its difference.
     * Only used for output neurons.
     * @param neuron The output neuron to calculate for.
     * @param diff The difference between the ideal and actual output.
     */
    private void calculateOutputDelta(final FreeformNeuron neuron,
            final double diff) {
        final double neuronOutput = neuron.getActivation();
        final double neuronSum = neuron.getInputSummation().getSum();
        double deriv = neuron.getInputSummation().getActivationFunction()
                .derivativeFunction(neuronSum, neuronOutput);
        if (this.fixFlatSpot
                && (neuron.getInputSummation().getActivationFunction()
                        instanceof ActivationSigmoid)) {
            deriv += FreeformPropagationTraining.FLAT_SPOT_CONST;
        }
        final double layerDelta = deriv * diff;
        neuron.setTempTraining(0, layerDelta);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void finishTraining() {
        this.network.tempTrainingClear();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public double getError() {
        return this.error;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    /**
     * @return True, if we are fixing the flat spot problem.
     */
    public boolean isFixFlatSpot() {
        return this.fixFlatSpot;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void iteration() {
        preIteration();
        this.iterationCount++;
        this.network.clearContext();

        if (this.batchSize == 0) {
            processPureBatch();
        } else {
            processBatches();
        }

        postIteration();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void iteration(final int count) {
        for (int i = 0; i < count; i++) {
            this.iteration();
        }
    }
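    /*
     * How the batch size selects between the two code paths below: with 950
     * training pairs and batchSize = 100, processBatches() calls learn()
     * nine times for the full batches and once more for the remaining 50
     * pairs; batchSize = 1 updates the weights after every pair (pure
     * online training); batchSize = 0 routes iteration() through
     * processPureBatch(), which accumulates gradients over the whole set
     * and calls learn() exactly once.
     */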
    /**
     * Process training for pure batch mode (one single batch).
     */
    protected void processPureBatch() {
        final ErrorCalculation errorCalc = new ErrorCalculation();
        this.visited.clear();

        for (final MLDataPair pair : this.training) {
            final MLData input = pair.getInput();
            final MLData ideal = pair.getIdeal();
            final MLData actual = this.network.compute(input);
            final double sig = pair.getSignificance();

            errorCalc.updateError(actual.getData(), ideal.getData(), sig);

            for (int i = 0; i < this.network.getOutputCount(); i++) {
                final double diff = (ideal.getData(i) - actual.getData(i))
                        * sig;
                final FreeformNeuron neuron = this.network.getOutputLayer()
                        .getNeurons().get(i);
                calculateOutputDelta(neuron, diff);
                calculateNeuronGradient(neuron);
            }
        }

        // Set the overall error.
        setError(errorCalc.calculate());

        // Learn for all data.
        learn();
    }

    /**
     * Process training batches.
     */
    protected void processBatches() {
        int lastLearn = 0;
        final ErrorCalculation errorCalc = new ErrorCalculation();
        this.visited.clear();

        for (final MLDataPair pair : this.training) {
            final MLData input = pair.getInput();
            final MLData ideal = pair.getIdeal();
            final MLData actual = this.network.compute(input);
            final double sig = pair.getSignificance();

            errorCalc.updateError(actual.getData(), ideal.getData(), sig);

            for (int i = 0; i < this.network.getOutputCount(); i++) {
                final double diff = (ideal.getData(i) - actual.getData(i))
                        * sig;
                final FreeformNeuron neuron = this.network.getOutputLayer()
                        .getNeurons().get(i);
                calculateOutputDelta(neuron, diff);
                calculateNeuronGradient(neuron);
            }

            // Are we at the end of a batch?
            lastLearn++;
            if (lastLearn >= this.batchSize) {
                lastLearn = 0;
                learn();
            }
        }

        // Handle any remaining data.
        if (lastLearn > 0) {
            learn();
        }

        // Set the overall error.
        setError(errorCalc.calculate());
    }

    /**
     * Learn for the entire network.
     */
    protected void learn() {
        this.network.performConnectionTask(new ConnectionTask() {
            @Override
            public void task(final FreeformConnection connection) {
                learnConnection(connection);
                connection.setTempTraining(0, 0);
            }
        });
    }

    /**
     * Learn for a single connection.
     * @param connection The connection to learn from.
     */
    protected abstract void learnConnection(FreeformConnection connection);

    /**
     * {@inheritDoc}
     */
    @Override
    public void setError(final double theError) {
        this.error = theError;
    }

    /**
     * Set if we should fix the flat spot problem.
     * @param fixFlatSpot True, if we should fix the flat spot problem.
     */
    public void setFixFlatSpot(final boolean fixFlatSpot) {
        this.fixFlatSpot = fixFlatSpot;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void setIteration(final int iteration) {
        this.iterationCount = iteration;
    }

    /**
     * @return The batch size.
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Set the batch size.
     * @param batchSize The batch size.
     */
    public void setBatchSize(final int batchSize) {
        this.batchSize = batchSize;
    }
}
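/*
 * Illustration only, not part of the Encog API: a minimal sketch of a
 * concrete subclass, showing how learnConnection() is expected to consume
 * the gradient accumulated in temp-training slot 0. It assumes that
 * FreeformConnection exposes setWeight(double) alongside the getWeight()
 * call used above; treat that signature as an assumption.
 *
 *     class SimpleGradientSketch extends FreeformPropagationTraining {
 *
 *         private static final long serialVersionUID = 1L;
 *         private final double learningRate;
 *
 *         SimpleGradientSketch(final FreeformNetwork theNetwork,
 *                 final MLDataSet theTraining, final double theLearningRate) {
 *             super(theNetwork, theTraining);
 *             this.learningRate = theLearningRate;
 *         }
 *
 *         @Override
 *         protected void learnConnection(final FreeformConnection connection) {
 *             // The gradients are built from (ideal - actual), so adding
 *             // the scaled gradient moves the error downhill.
 *             connection.setWeight(connection.getWeight()
 *                     + this.learningRate * connection.getTempTraining(0));
 *         }
 *     }
 *
 * Typical use of such a trainer (network and data construction elided):
 *
 *     FreeformPropagationTraining train =
 *             new SimpleGradientSketch(network, trainingSet, 0.7);
 *     train.setBatchSize(0); // 0 = pure batch, the default
 *     do {
 *         train.iteration();
 *     } while (train.getError() > 0.01);
 *     train.finishTraining();
 */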