/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.networks.training.propagation;

import java.util.Random;

import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.IntRange;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.BatchSize;
import org.encog.neural.networks.training.Train;
import org.encog.util.EncogValidate;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.concurrency.TaskGroup;
import org.encog.util.logging.EncogLogging;

/**
 * Implements basic functionality that is needed by each of the propagation
 * methods. The specifics of each propagation method are implemented inside
 * of the PropagationMethod interface implementors.
 *
 * @author jheaton
 *
 */
public abstract class Propagation extends BasicTraining implements Train,
		MultiThreadable, BatchSize, GradientWorkerOwner {

	/**
	 * Used to generate randomness for dropout.
	 */
	protected Random dropoutRandomSource = new Random();

	/**
	 * The dropout rate, between 0 and 1.
	 */
	private double dropoutRate = 0;

	/**
	 * The current flat network we are using for training, or null for none.
	 */
	private FlatNetwork currentFlatNetwork;

	/**
	 * The number of threads to use.
	 */
	private int numThreads;

	/**
	 * The gradients.
	 */
	protected double[] gradients;

	/**
	 * The last gradients, from the last training iteration.
	 */
	private final double[] lastGradient;

	/**
	 * The network to train.
	 */
	protected final ContainsFlat network;

	/**
	 * The network in indexable form.
	 */
	private final MLDataSet indexable;

	/**
	 * The workers.
	 */
	private GradientWorker[] workers;

	/**
	 * The total error across the workers; used to compute the average error.
	 */
	private double totalError;

	/**
	 * Reported exception from the threads.
	 */
	private Throwable reportedException;

	/**
	 * The iteration.
	 */
	private int iteration;

	/**
	 * The flat spot constants.
	 */
	private double[] flatSpot;

	/**
	 * Should we fix flat spots.
	 */
	private boolean shouldFixFlatSpot;

	/**
	 * The error function.
	 */
	private ErrorFunction ef = new LinearErrorFunction();

	/**
	 * The batch size. Specify 1 for pure online training. Specify 0 for pure
	 * batch training (complete training set in one batch). Otherwise specify
	 * the batch size for mini-batch training.
	 */
	private int batchSize = 0;

	/**
	 * How much to apply l1 regularization penalty, 0 (default) for none.
	 */
	private double l1;

	/**
	 * How much to apply l2 regularization penalty, 0 (default) for none.
	 */
	private double l2;

	/**
	 * True once training has been finalized by finishTraining.
	 */
	private boolean finalized = false;
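	/*
	 * Usage sketch (illustrative only, not part of the API): Propagation is
	 * abstract, so training normally goes through a concrete subclass such as
	 * ResilientPropagation or Backpropagation. The snippet assumes a trained
	 * "network" (e.g. a BasicNetwork) and a "trainingSet" MLDataSet already
	 * exist elsewhere; the names and the 0.01 target error are placeholders.
	 *
	 *   Propagation train = new ResilientPropagation(network, trainingSet);
	 *   train.setBatchSize(0);   // 0 = full batch, 1 = online, n = mini-batch
	 *   do {
	 *       train.iteration();
	 *   } while (train.getError() > 0.01);
	 *   train.finishTraining();
	 */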
	/**
	 * Construct a propagation object.
	 *
	 * @param network
	 *            The network.
	 * @param training
	 *            The training set.
	 */
	public Propagation(final ContainsFlat network, final MLDataSet training) {
		super(TrainingImplementationType.Iterative);
		this.network = network;
		this.currentFlatNetwork = network.getFlat();
		setTraining(training);

		this.gradients = new double[this.currentFlatNetwork.getWeights().length];
		this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];

		this.indexable = training;
		this.numThreads = 0;
		this.reportedException = null;
		this.shouldFixFlatSpot = true;
	}

	/**
	 * Change the dropout rate.
	 *
	 * @param rate
	 *            The dropout rate.
	 */
	public void setDroupoutRate(double rate) {
		this.dropoutRate = rate;
	}

	/**
	 * @return the current dropout rate
	 */
	public double getDropoutRate() {
		return this.dropoutRate;
	}

	/**
	 * Should be called after training has completed and the iteration method
	 * will not be called any further.
	 */
	@Override
	public void finishTraining() {
		finishTraining(this.dropoutRate);
	}

	/**
	 * Finish training, scaling the weights to compensate for dropout.
	 *
	 * @param dropoutRate
	 *            The dropout rate that was used during training.
	 */
	public void finishTraining(double dropoutRate) {
		if (!this.finalized) {
			final double[] weights = this.currentFlatNetwork.getWeights();
			if (dropoutRate > 0) {
				for (int i = 0; i < weights.length; i++) {
					weights[i] *= (1 - dropoutRate);
				}
			}
			this.finalized = true;
		}

		super.finishTraining();
	}

	/**
	 * @return the currentFlatNetwork
	 */
	public FlatNetwork getCurrentFlatNetwork() {
		return this.currentFlatNetwork;
	}

	/**
	 * {@inheritDoc}
	 */
	public MLMethod getMethod() {
		return this.network;
	}

	/**
	 * Perform one training iteration.
	 */
	public void iteration() {
		iteration(1);
	}

	/**
	 * Increase the iteration count by one.
	 */
	public void rollIteration() {
		this.iteration++;
	}

	/**
	 * Process as pure batch (size 0). Batch size equal to training set size.
	 */
	private void processPureBatch() {
		calculateGradients();

		if (this.currentFlatNetwork.isLimited()) {
			learnLimited();
		} else {
			learn();
		}
	}

	/**
	 * Process the training set as mini-batches of the configured batch size.
	 */
	private void processBatches() {
		if (this.workers == null) {
			init();
		}

		if (this.currentFlatNetwork.getHasContext()) {
			this.workers[0].getNetwork().clearContext();
		}

		this.workers[0].getErrorCalculation().reset();

		int lastLearn = 0;
		for (int i = 0; i < this.getTraining().size(); i++) {
			this.workers[0].run(i);

			lastLearn++;
			if (lastLearn >= this.batchSize) {
				if (this.currentFlatNetwork.isLimited()) {
					learnLimited();
				} else {
					learn();
				}
				lastLearn = 0;
			}
		}

		// handle any remaining learning
		if (lastLearn > 0) {
			learn();
		}

		this.setError(this.workers[0].getErrorCalculation().calculate());
	}
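	/*
	 * Batch-size behaviour, by way of example (illustrative numbers only):
	 * with batchSize = 10 and 95 training records, the loop in
	 * processBatches() calls learn() after every tenth record (nine times)
	 * and once more after the loop for the trailing five records. A batchSize
	 * of 1 gives pure online training, while a batchSize of 0 bypasses this
	 * method entirely and processPureBatch() treats the whole set as one
	 * batch.
	 */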
	/**
	 * Perform the specified number of training iterations. This can be more
	 * efficient than single training iterations. This is particularly true if
	 * you are training with a GPU.
	 *
	 * @param count
	 *            The number of training iterations.
	 */
	@Override
	public void iteration(final int count) {

		try {
			for (int i = 0; i < count; i++) {

				preIteration();

				rollIteration();

				if (this.batchSize == 0) {
					processPureBatch();
				} else {
					processBatches();
				}

				for (final GradientWorker worker : this.workers) {
					EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(),
							0, worker.getWeights(), 0,
							this.currentFlatNetwork.getWeights().length);
				}

				if (this.currentFlatNetwork.getHasContext()) {
					copyContexts();
				}

				if (this.reportedException != null) {
					throw (new EncogError(this.reportedException));
				}

				postIteration();

				EncogLogging.log(EncogLogging.LEVEL_INFO,
						"Training iteration done, error: " + getError());
			}
		} catch (final ArrayIndexOutOfBoundsException ex) {
			EncogValidate.validateNetworkForTraining(this.network,
					getTraining());
			throw new EncogError(ex);
		}
	}

	/**
	 * Set the number of threads. Specify zero to tell Encog to automatically
	 * determine the best number of threads for the processor. If OpenCL is
	 * used as the target device, then this value is not used.
	 *
	 * @param numThreads
	 *            The number of threads.
	 */
	public void setThreadCount(final int numThreads) {
		this.numThreads = numThreads;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public int getThreadCount() {
		return this.numThreads;
	}

	/**
	 * Default is true. Call this with false to disable the flat spot fix.
	 *
	 * For more info on flat spot:
	 *
	 * http://www.heatonresearch.com/wiki/Flat_Spot
	 *
	 * @param b
	 *            True to fix flat spots, false otherwise.
	 */
	public void fixFlatSpot(boolean b) {
		this.shouldFixFlatSpot = b;
	}

	/**
	 * Set the error function used to calculate the error to propagate.
	 *
	 * @param ef
	 *            The error function.
	 */
	public void setErrorFunction(ErrorFunction ef) {
		this.ef = ef;
	}

	/**
	 * Calculate the gradients.
	 */
	public void calculateGradients() {
		if (this.workers == null) {
			init();
		}

		if (this.currentFlatNetwork.getHasContext()) {
			this.workers[0].getNetwork().clearContext();
		}

		this.totalError = 0;

		if (this.workers.length > 1) {

			final TaskGroup group = EngineConcurrency.getInstance()
					.createTaskGroup();

			for (final GradientWorker worker : this.workers) {
				EngineConcurrency.getInstance().processTask(worker, group);
			}

			group.waitForComplete();
		} else {
			this.workers[0].run();
		}

		this.setError(this.totalError / this.workers.length);
	}

	/**
	 * Copy the contexts to keep them consistent with multithreaded training.
	 */
	private void copyContexts() {

		// copy the contexts (layer output) from each group to the next group
		for (int i = 0; i < (this.workers.length - 1); i++) {
			final double[] src = this.workers[i].getNetwork().getLayerOutput();
			final double[] dst = this.workers[i + 1].getNetwork()
					.getLayerOutput();
			EngineArray.arrayCopy(src, dst);
		}

		// copy the contexts from the final group to the real network
		EngineArray.arrayCopy(this.workers[this.workers.length - 1]
				.getNetwork().getLayerOutput(), this.currentFlatNetwork
				.getLayerOutput());
	}

	/**
	 * Init the process.
	 */
	private void init() {

		// fix flat spot, if needed
		this.flatSpot = new double[this.currentFlatNetwork
				.getActivationFunctions().length];

		if (this.shouldFixFlatSpot) {
			for (int i = 0; i < this.currentFlatNetwork
					.getActivationFunctions().length; i++) {
				final ActivationFunction af = this.currentFlatNetwork
						.getActivationFunctions()[i];

				if (af instanceof ActivationSigmoid) {
					this.flatSpot[i] = 0.1;
				} else {
					this.flatSpot[i] = 0.0;
				}
			}
		} else {
			EngineArray.fill(this.flatSpot, 0.0);
		}

		// setup workers

		// Do not use multi-threading for non-pure batch training.
		//
		// At some point it would be good to add multi-threading
		// for batch sizes that are large enough.
		//
		// Multi-threading cannot be added for pure (size 1)
		// online training.
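		// For example (illustrative numbers only, pure batch training): with
		// numThreads = 0 on a four-core machine and 10,000 training records,
		// DetermineWorkload below would typically settle on four threads, and
		// calculateWorkers() would return four IntRange segments of roughly
		// 2,500 records each, every one driven by its own GradientWorker over
		// a clone of the network.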
		if (this.batchSize != 0) {
			this.numThreads = 1;
		}

		final DetermineWorkload determine = new DetermineWorkload(
				this.numThreads, (int) this.indexable.getRecordCount());

		int actualThreadCount = determine.getThreadCount();

		this.workers = new GradientWorker[actualThreadCount];

		int index = 0;
		for (final IntRange r : determine.calculateWorkers()) {
			this.workers[index++] = new GradientWorker(
					this.currentFlatNetwork.clone(), this,
					this.indexable.openAdditional(), r.getLow(), r.getHigh(),
					this.flatSpot, this.ef);
		}

		initOthers();
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public void report(final double[] gradients, final double error,
			final Throwable ex) {
		synchronized (this) {
			if (ex == null) {
				for (int i = 0; i < gradients.length; i++) {
					this.gradients[i] += gradients[i];
				}
				this.totalError += error;
			} else {
				this.reportedException = ex;
			}
		}
	}

	/**
	 * Apply and learn.
	 */
	protected void learn() {
		final double[] weights = this.currentFlatNetwork.getWeights();

		if (this.dropoutRate > 0) {
			for (int i = 0; i < this.gradients.length; i++) {
				weights[i] += updateWeight(this.gradients, this.lastGradient,
						i, this.dropoutRate);
				this.gradients[i] = 0;
			}
		} else {
			for (int i = 0; i < this.gradients.length; i++) {
				weights[i] += updateWeight(this.gradients, this.lastGradient,
						i);
				this.gradients[i] = 0;
			}
		}
	}

	/**
	 * Apply and learn. This is the same as learn, but it checks to see if any
	 * of the weights are below the limit threshold. In this case, these
	 * weights are zeroed out. Having two methods allows the regular learn
	 * method, which is what is usually used, to be as fast as possible.
	 */
	protected void learnLimited() {
		final double limit = this.currentFlatNetwork.getConnectionLimit();
		final double[] weights = this.currentFlatNetwork.getWeights();

		if (this.dropoutRate > 0) {
			for (int i = 0; i < this.gradients.length; i++) {
				if (Math.abs(weights[i]) < limit) {
					weights[i] = 0;
				} else {
					weights[i] += updateWeight(this.gradients,
							this.lastGradient, i, this.dropoutRate);
				}
				this.gradients[i] = 0;
			}
		} else {
			for (int i = 0; i < this.gradients.length; i++) {
				if (Math.abs(weights[i]) < limit) {
					weights[i] = 0;
				} else {
					weights[i] += updateWeight(this.gradients,
							this.lastGradient, i);
				}
				this.gradients[i] = 0;
			}
		}
	}

	/**
	 * Perform any propagation-method-specific initialization. Called from
	 * init() after the workers have been created.
	 */
	public abstract void initOthers();

	/**
	 * Update a weight, the means by which weights are updated vary depending
	 * on the training.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @return The update value.
	 */
	public abstract double updateWeight(double[] gradients,
			double[] lastGradient, int index);

	/**
	 * Update a weight using dropout, the means by which weights are updated
	 * vary depending on the training.
	 *
	 * @param gradients
	 *            The gradients.
	 * @param lastGradient
	 *            The last gradients.
	 * @param index
	 *            The index.
	 * @param dropoutRate
	 *            The dropout rate.
	 * @return The update value.
	 */
	public abstract double updateWeight(double[] gradients,
			double[] lastGradient, int index, double dropoutRate);
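	/*
	 * Sketch of what a concrete subclass might do in updateWeight(). This is
	 * an illustrative, simplified backpropagation-style rule, not the exact
	 * implementation of any particular Encog trainer; "learningRate",
	 * "momentum" and "lastDelta" would be fields of that subclass:
	 *
	 *   public double updateWeight(double[] gradients, double[] lastGradient,
	 *           int index) {
	 *       double delta = (gradients[index] * learningRate)
	 *               + (lastDelta[index] * momentum);
	 *       lastDelta[index] = delta;
	 *       return delta; // Propagation.learn() adds this value to the weight
	 *   }
	 */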
	/**
	 * @return the lastGradient
	 */
	public double[] getLastGradient() {
		return lastGradient;
	}

	/**
	 * {@inheritDoc}
	 */
	public int getBatchSize() {
		return this.batchSize;
	}

	/**
	 * {@inheritDoc}
	 */
	public void setBatchSize(int theBatchSize) {
		this.batchSize = theBatchSize;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public double getL1() {
		return l1;
	}

	/**
	 * @param l1
	 *            How much to apply l1 regularization penalty, 0 (default) for
	 *            none.
	 */
	public void setL1(double l1) {
		this.l1 = l1;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public double getL2() {
		return l2;
	}

	/**
	 * @param l2
	 *            How much to apply l2 regularization penalty, 0 (default) for
	 *            none.
	 */
	public void setL2(double l2) {
		this.l2 = l2;
	}

}