/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.neural.freeform.training; import java.io.Serializable; import org.encog.mathutil.EncogMath; import org.encog.ml.data.MLDataSet; import org.encog.neural.freeform.FreeformConnection; import org.encog.neural.freeform.FreeformNetwork; import org.encog.neural.freeform.task.ConnectionTask; import org.encog.neural.networks.training.propagation.TrainingContinuation; import org.encog.neural.networks.training.propagation.resilient.RPROPConst; public class FreeformResilientPropagation extends FreeformPropagationTraining implements Serializable { /** * The serial ID. */ private static final long serialVersionUID = 1L; /** * Temp value #0: the gradient. */ public static final int TEMP_GRADIENT = 0; /** * Temp value #1: the last gradient. */ public static final int TEMP_LAST_GRADIENT = 1; /** * Temp value #2: the update. */ public static final int TEMP_UPDATE = 2; /** * Temp value #3: the the last weight delta. */ public static final int TEMP_LAST_WEIGHT_DELTA = 3; /** * The max step. */ private final double maxStep; /** * Construct the RPROP trainer, Use default intiial update and max step. * @param theNetwork The network to train. * @param theTraining The training set. */ public FreeformResilientPropagation(final FreeformNetwork theNetwork, final MLDataSet theTraining) { this(theNetwork, theTraining, RPROPConst.DEFAULT_INITIAL_UPDATE, RPROPConst.DEFAULT_MAX_STEP); } /** * Construct the RPROP trainer. * @param theNetwork The network to train. * @param theTraining The training set. * @param initialUpdate The initial update. * @param theMaxStep The max step. */ public FreeformResilientPropagation(final FreeformNetwork theNetwork, final MLDataSet theTraining, final double initialUpdate, final double theMaxStep) { super(theNetwork, theTraining); this.maxStep = theMaxStep; theNetwork.tempTrainingAllocate(1, 4); theNetwork.performConnectionTask(new ConnectionTask() { @Override public void task(final FreeformConnection c) { c.setTempTraining(FreeformResilientPropagation.TEMP_UPDATE, initialUpdate); } }); } /** * {@inheritDoc} */ @Override protected void learnConnection(final FreeformConnection connection) { // multiply the current and previous gradient, and take the // sign. We want to see if the gradient has changed its sign. final int change = EncogMath .sign(connection .getTempTraining(FreeformResilientPropagation.TEMP_GRADIENT) * connection .getTempTraining(FreeformResilientPropagation.TEMP_LAST_GRADIENT)); double weightChange = 0; // if the gradient has retained its sign, then we increase the // delta so that it will converge faster if (change > 0) { double delta = connection .getTempTraining(FreeformResilientPropagation.TEMP_UPDATE) * RPROPConst.POSITIVE_ETA; delta = Math.min(delta, this.maxStep); weightChange = EncogMath .sign(connection .getTempTraining(FreeformResilientPropagation.TEMP_GRADIENT)) * delta; connection.setTempTraining( FreeformResilientPropagation.TEMP_UPDATE, delta); connection .setTempTraining( FreeformResilientPropagation.TEMP_LAST_GRADIENT, connection .getTempTraining(FreeformResilientPropagation.TEMP_GRADIENT)); } else if (change < 0) { // if change<0, then the sign has changed, and the last // delta was too big double delta = connection .getTempTraining(FreeformResilientPropagation.TEMP_UPDATE) * RPROPConst.NEGATIVE_ETA; delta = Math.max(delta, RPROPConst.DELTA_MIN); connection.setTempTraining( FreeformResilientPropagation.TEMP_UPDATE, delta); weightChange = -connection .getTempTraining(FreeformResilientPropagation.TEMP_LAST_WEIGHT_DELTA); // set the previous gradient to zero so that there will be no // adjustment the next iteration connection.setTempTraining( FreeformResilientPropagation.TEMP_LAST_GRADIENT, 0); } else if (change == 0) { // if change==0 then there is no change to the delta final double delta = connection .getTempTraining(FreeformResilientPropagation.TEMP_UPDATE); weightChange = EncogMath .sign(connection .getTempTraining(FreeformResilientPropagation.TEMP_GRADIENT)) * delta; connection .setTempTraining( FreeformResilientPropagation.TEMP_LAST_GRADIENT, connection .getTempTraining(FreeformResilientPropagation.TEMP_GRADIENT)); } // apply the weight change, if any connection.addWeight(weightChange); connection.setTempTraining( FreeformResilientPropagation.TEMP_LAST_WEIGHT_DELTA, weightChange); } /** * {@inheritDoc} */ @Override public TrainingContinuation pause() { // TODO Auto-generated method stub return null; } /** * {@inheritDoc} */ @Override public void resume(final TrainingContinuation state) { // TODO Auto-generated method stub } }