/*
 * Encog(tm) Core v2.5 - Java Version
 * http://www.heatonresearch.com/encog/
 * http://code.google.com/p/encog-java/
 
 * Copyright 2008-2010 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *   
 * For more information on Heaton Research copyrights, licenses 
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.neural.networks.training.competitive;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import org.encog.engine.util.Format;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.competitive.neighborhood.NeighborhoodFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This class implements competitive training, which would be used in a
 * winner-take-all neural network, such as the self organizing map (SOM). This
 * is an unsupervised training method, no ideal data is needed on the training
 * set. If ideal data is provided, it will be ignored.
 * 
 * Training is done by looping over all of the training elements and calculating
 * a "best matching unit" (BMU). This BMU output neuron is then adjusted to
 * better "learn" this pattern. Additionally, this training may be applied to
 * other "nearby" output neurons. The degree to which nearby neurons are updated
 * is defined by the neighborhood function.
 * 
 * A neighborhood function is required to determine the degree to which
 * neighboring neurons (to the winning neuron) are updated by each training
 * iteration.
 * 
 * Because this is unsupervised training, calculating an error to measure
 * progress by is difficult. The error is defined to be the "worst", or longest,
 * Euclidean distance of any of the BMU's. This value should be minimized, as
 * learning progresses.
 * 
 * Because only the BMU neuron and its close neighbors are updated, you can end
 * up with some output neurons that learn nothing. By default these neurons are
 * not forced to win patterns that are not represented well. Forcing them to
 * win spreads out the workload among all output neurons. This feature is not
 * used by default, but can be enabled by setting the "forceWinner" property.
 * 
 * @author jheaton
 * 
 */
public class CompetitiveTraining extends BasicTraining implements LearningRate {

    /**
     * The neighborhood function to use to determine to what degree a neuron
     * should be "trained".
     */
    private final NeighborhoodFunction neighborhood;

    /**
     * The learning rate. To what degree should changes be applied.
     */
    private double learningRate;

    /**
     * The network being trained.
     */
    private final BasicNetwork network;

    /**
     * The input layer.
     */
    private final Layer inputLayer;

    /**
     * The output layer.
     */
    private final Layer outputLayer;

    /**
     * A collection of the synapses being modified.
     */
    private final Collection<Synapse> synapses;

    /**
     * How many neurons in the input layer.
     */
    private final int inputNeuronCount;

    /**
     * How many neurons in the output layer.
     */
    private final int outputNeuronCount;

    /**
     * Utility class used to determine the BMU.
     */
    private final BestMatchingUnit bmuUtil;

    /**
     * Holds the corrections for any matrix being trained.
     */
    private final Map<Synapse, Matrix> correctionMatrix = 
        new HashMap<Synapse, Matrix>();

    /**
     * True if a winner is to be forced, see class description, or forceWinners
     * method. By default, this is false (the constructor disables it).
     */
    private boolean forceWinner;

    /**
     * When used with autodecay, this is the starting learning rate.
     */
    private double startRate;

    /**
     * When used with autodecay, this is the ending learning rate.
     */
    private double endRate;

    /**
     * When used with autodecay, this is the starting radius.
     */
    private double startRadius;

    /**
     * When used with autodecay, this is the ending radius.
     */
    private double endRadius;

    /**
     * This is the current autodecay learning rate step. It is added to the
     * learning rate on every call to autoDecay.
     */
    private double autoDecayRate;

    /**
     * This is the current autodecay radius step. It is added to the radius
     * on every call to autoDecay.
     */
    private double autoDecayRadius;

    /**
     * The logging object.
     */
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * The current radius.
     */
    private double radius;

    /**
     * Create an instance of competitive training.
     * 
     * @param network
     *            The network to train.
     * @param learningRate
     *            The learning rate, how much to apply per iteration.
     * @param training
     *            The training set (unsupervised).
     * @param neighborhood
     *            The neighborhood function to use.
     */
    public CompetitiveTraining(final BasicNetwork network,
            final double learningRate, final NeuralDataSet training,
            final NeighborhoodFunction neighborhood) {
        this.neighborhood = neighborhood;
        setTraining(training);
        this.learningRate = learningRate;
        this.network = network;
        this.inputLayer = network.getLayer(BasicNetwork.TAG_INPUT);
        this.outputLayer = network.getLayer(BasicNetwork.TAG_OUTPUT);
        this.synapses = network.getStructure().getPreviousSynapses(
                this.outputLayer);
        this.inputNeuronCount = this.inputLayer.getNeuronCount();
        this.outputNeuronCount = this.outputLayer.getNeuronCount();
        this.forceWinner = false;
        setError(0);

        // setup the correction matrix: one scratch matrix per trained
        // synapse, with the same dimensions as that synapse's weights
        for (final Synapse synapse : this.synapses) {
            final Matrix weights = synapse.getMatrix();
            this.correctionMatrix.put(synapse,
                    new Matrix(weights.getRows(), weights.getCols()));
        }

        // create the BMU class
        this.bmuUtil = new BestMatchingUnit(this);
    }

    /**
     * Loop over the synapses to be trained and apply any corrections that were
     * determined by this training iteration.
     */
    private void applyCorrection() {
        for (final Entry<Synapse, Matrix> entry : this.correctionMatrix
                .entrySet()) {
            entry.getKey().getMatrix().set(entry.getValue());
        }
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Should be called each iteration if autodecay is desired. Moves the
     * radius and learning rate one step towards their ending values and
     * passes the resulting radius to the neighborhood function.
     */
    public void autoDecay() {
        // Clamp at the end value so that an accumulated rounding error in the
        // per-iteration step cannot push the value below its configured end.
        if (this.radius > this.endRadius) {
            this.radius = Math.max(this.endRadius, 
                    this.radius + this.autoDecayRadius);
        }

        if (this.learningRate > this.endRate) {
            this.learningRate = Math.max(this.endRate, 
                    this.learningRate + this.autoDecayRate);
        }
        getNeighborhood().setRadius(this.radius);
    }

    /**
     * Copy the specified input pattern to the weight matrix. This causes an
     * output neuron to learn this pattern "exactly". This is useful when a
     * winner is to be forced.
     * 
     * @param synapse
     *            The synapse that is the target of the copy.
     * @param outputNeuron
     *            The output neuron to set.
     * @param input
     *            The input pattern to copy.
     */
    private void copyInputPattern(final Synapse synapse,
            final int outputNeuron, final NeuralData input) {
        final Matrix weights = synapse.getMatrix();
        for (int inputNeuron = 0; inputNeuron < this.inputNeuronCount; 
                inputNeuron++) {
            weights.set(inputNeuron, outputNeuron, input.getData(inputNeuron));
        }
    }

    /**
     * Called to decay the learning rate and radius by the specified amount.
     * 
     * NOTE(review): unlike {@link #decay(double, double)} and
     * {@link #autoDecay()}, this overload only updates the radius tracked by
     * this trainer and does not pass it on to the neighborhood function.
     * Confirm whether that is intentional before relying on it.
     * 
     * @param d The percent to decay by.
     */
    public void decay(final double d) {
        this.radius *= (1.0 - d);
        this.learningRate *= (1.0 - d);
    }

    /**
     * Decay the learning rate and radius by the specified amount. The new
     * radius is also passed to the neighborhood function.
     * 
     * @param decayRate The percent to decay the learning rate by.
     * @param decayRadius The percent to decay the radius by.
     */
    public void decay(final double decayRate, final double decayRadius) {
        this.radius *= (1.0 - decayRadius);
        this.learningRate *= (1.0 - decayRate);
        getNeighborhood().setRadius(this.radius);
    }

    /**
     * Determine the weight adjustment for a single neuron during a training
     * iteration.
     * 
     * @param weight
     *            The starting weight.
     * @param input
     *            The input to this neuron.
     * @param currentNeuron
     *            The neuron whose weight is being updated.
     * @param bmu
     *            The neuron that "won", the best matching unit.
     * @return The new weight value.
     */
    private double determineNewWeight(final double weight, final double input,
            final int currentNeuron, final int bmu) {
        // Move the weight towards the input, scaled by the neighborhood
        // function and the learning rate.
        return weight
                + (this.neighborhood.function(currentNeuron, bmu)
                        * this.learningRate * (input - weight));
    }

    /**
     * Force any neurons that did not win to off-load patterns from overworked
     * neurons.
     * 
     * @param synapse
     *            The synapse to modify.
     * @param won
     *            An array that specifies how many times each output neuron has
     *            "won".
     * @param leastRepresented
     *            The training pattern that is the least represented by this
     *            neural network.
     * @return True if a winner was forced.
     */
    private boolean forceWinners(final Synapse synapse, final int[] won,
            final NeuralData leastRepresented) {

        // Running maximum; NEGATIVE_INFINITY (not Double.MIN_VALUE, which is
        // the smallest positive double) is the correct identity for a maximum.
        double maxActivation = Double.NEGATIVE_INFINITY;
        int maxActivationNeuron = -1;

        final NeuralData output = this.network.compute(leastRepresented);

        // Loop over all of the output neurons. Consider any neurons that were
        // not the BMU (winner) for any pattern. Track which of these
        // non-winning neurons had the highest activation.
        for (int outputNeuron = 0; outputNeuron < won.length; outputNeuron++) {
            // Only consider neurons that did not "win".
            if (won[outputNeuron] != 0) {
                continue;
            }

            final double activation = output.getData(outputNeuron);
            if ((maxActivationNeuron == -1) || (activation > maxActivation)) {
                maxActivation = activation;
                maxActivationNeuron = outputNeuron;
            }
        }

        // If no neuron failed to win, there is nothing to force.
        if (maxActivationNeuron == -1) {
            return false;
        }

        // A neuron was found that did not activate for any patterns, so
        // force it to "win" the least represented pattern.
        copyInputPattern(synapse, maxActivationNeuron, leastRepresented);
        return true;
    }

    /**
     * @return The input neuron count.
     */
    public int getInputNeuronCount() {
        return this.inputNeuronCount;
    }

    /**
     * @return The current learning rate.
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @return The network neighborhood function.
     */
    public NeighborhoodFunction getNeighborhood() {
        return this.neighborhood;
    }

    /**
     * @return The network being trained.
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return The output neuron count.
     */
    public int getOutputNeuronCount() {
        return this.outputNeuronCount;
    }

    /**
     * @return Is a winner to be forced of neurons that do not learn. See class
     *         description for more info.
     */
    public boolean isForceWinner() {
        return this.forceWinner;
    }

    /**
     * Perform one training iteration.
     */
    public void iteration() {

        if (this.logger.isInfoEnabled()) {
            this.logger.info("Performing Competitive Training iteration.");
        }

        preIteration();

        // Reset the BMU and begin this iteration.
        this.bmuUtil.reset();
        final int[] won = new int[this.outputNeuronCount];
        double leastRepresentedActivation = Double.MAX_VALUE;
        NeuralData leastRepresented = null;

        // The synapses are processed parallel to each other.
        for (final Synapse synapse : this.synapses) {

            // Reset the correction matrix for this synapse and iteration.
            final Matrix correction = this.correctionMatrix.get(synapse);
            correction.clear();

            // Determine the BMU for each training element.
            for (final NeuralDataPair pair : getTraining()) {
                final NeuralData input = pair.getInput();
                final int bmu = this.bmuUtil.calculateBMU(synapse, input);

                // If we are to force a winner each time, then track how many
                // times each output neuron becomes the BMU (winner).
                if (this.forceWinner) {
                    won[bmu]++;

                    // Get the "output" from the network for this pattern. This
                    // gets the activation level of the BMU.
                    final NeuralData output = this.network.compute(input);
                    final double bmuActivation = output.getData(bmu);

                    // Track which training entry produces the least BMU. This
                    // pattern is the least represented by the network.
                    if (bmuActivation < leastRepresentedActivation) {
                        leastRepresentedActivation = bmuActivation;
                        leastRepresented = input;
                    }
                }

                train(bmu, synapse, input);
            }

            // When forcing winners, a successfully forced winner replaces
            // the normal correction step for this synapse; otherwise the
            // accumulated corrections are applied.
            if (!this.forceWinner
                    || !forceWinners(synapse, won, leastRepresented)) {
                applyCorrection();
            }
        }

        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);

        // update the error
        setError(this.bmuUtil.getWorstDistance());

        postIteration();
    }

    /**
     * Setup autodecay. This will decrease the radius and learning rate from
     * the start values to the end values.
     * 
     * @param plannedIterations The number of iterations that are planned.
     *            This allows the decay rate to be determined. Must be
     *            greater than zero.
     * @param startRate The starting learning rate.
     * @param endRate The ending learning rate.
     * @param startRadius The starting radius.
     * @param endRadius The ending radius.
     * @throws IllegalArgumentException
     *             If plannedIterations is not greater than zero.
     */
    public void setAutoDecay(final int plannedIterations,
            final double startRate, final double endRate,
            final double startRadius, final double endRadius) {
        // A zero count would yield an infinite/NaN step and a negative count
        // would make the values grow instead of decay.
        if (plannedIterations <= 0) {
            throw new IllegalArgumentException(
                    "plannedIterations must be greater than zero: "
                            + plannedIterations);
        }

        this.startRate = startRate;
        this.endRate = endRate;
        this.startRadius = startRadius;
        this.endRadius = endRadius;
        this.autoDecayRadius = (endRadius - startRadius) / plannedIterations;
        this.autoDecayRate = (endRate - startRate) / plannedIterations;
        setParams(this.startRate, this.startRadius);
    }

    /**
     * Determine if a winner is to be forced. See class description for more
     * info.
     * 
     * @param forceWinner
     *            True if a winner is to be forced.
     */
    public void setForceWinner(final boolean forceWinner) {
        this.forceWinner = forceWinner;
    }

    /**
     * Set the learning rate. This is the rate at which the weights are changed.
     * 
     * @param rate
     *            The learning rate.
     */
    public void setLearningRate(final double rate) {
        this.learningRate = rate;
    }

    /**
     * Set the learning rate and radius. The radius is also passed to the
     * neighborhood function.
     * 
     * @param rate The new learning rate.
     * @param radius The new radius.
     */
    public void setParams(final double rate, final double radius) {
        this.radius = radius;
        this.learningRate = rate;
        getNeighborhood().setRadius(radius);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String toString() {
        final StringBuilder result = new StringBuilder();
        result.append("Rate=");
        result.append(Format.formatPercent(this.learningRate));
        result.append(", Radius=");
        result.append(Format.formatDouble(this.radius, 2));
        return result.toString();
    }

    /**
     * Train for the specified synapse and BMU.
     * 
     * @param bmu
     *            The best matching unit for this input.
     * @param synapse
     *            The synapse to train.
     * @param input
     *            The input to train for.
     */
    private void train(final int bmu, final Synapse synapse,
            final NeuralData input) {
        // adjust the weight for the BMU and its neighborhood
        for (int outputNeuron = 0; outputNeuron < this.outputNeuronCount; 
                outputNeuron++) {
            trainPattern(synapse, input, outputNeuron, bmu);
        }
    }

    /**
     * Train the specified pattern. Find a winning neuron and adjust all
     * neurons according to the neighborhood function.
     * 
     * @param pattern The pattern to train.
     */
    public void trainPattern(final NeuralData pattern) {
        for (final Synapse synapse : this.synapses) {
            final int bmu = this.bmuUtil.calculateBMU(synapse, pattern);
            train(bmu, synapse, pattern);
        }
        applyCorrection();
    }

    /**
     * Train for the specified pattern. The new weights are written to the
     * correction matrix for the synapse, not to the synapse itself.
     * 
     * @param synapse
     *            The synapse to train.
     * @param input
     *            The input pattern to train for.
     * @param current
     *            The current output neuron being trained.
     * @param bmu
     *            The best matching unit, or winning output neuron.
     */
    private void trainPattern(final Synapse synapse, final NeuralData input,
            final int current, final int bmu) {
        final Matrix correction = this.correctionMatrix.get(synapse);
        final Matrix weights = synapse.getMatrix();

        for (int inputNeuron = 0; inputNeuron < this.inputNeuronCount; 
                inputNeuron++) {

            final double currentWeight = weights.get(inputNeuron, current);
            final double inputValue = input.getData(inputNeuron);

            final double newWeight = determineNewWeight(currentWeight,
                    inputValue, current, bmu);

            correction.set(inputNeuron, current, newWeight);
        }
    }

}