/* * Encog(tm) Core v2.5 - Java Version * http://www.heatonresearch.com/encog/ * http://code.google.com/p/encog-java/ * Copyright 2008-2010 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.neural.prune; import java.util.Collection; import org.encog.mathutil.matrices.Matrix; import org.encog.mathutil.matrices.MatrixMath; import org.encog.mathutil.randomize.Distort; import org.encog.neural.NeuralNetworkError; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.layers.Layer; import org.encog.neural.networks.synapse.Synapse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Prune a neural network selectively. This class allows you to either add or * remove neurons from layers of a neural network. Tools * * @author jheaton * */ public class PruneSelective { /** * */ private final BasicNetwork network; /** * The logging object. */ @SuppressWarnings("unused") private final Logger logger = LoggerFactory.getLogger(this.getClass()); /** * Construct an object prune the neural network. * * @param network * The network to prune. */ public PruneSelective(final BasicNetwork network) { this.network = network; } /** * Change the neuron count for the network. If the count is increased then a * zero-weighted neuron is added, which will not affect the output of the * neural network. If the neuron count is decreased, then the weakest neuron * will be removed. * * @param layer * The layer to adjust. * @param neuronCount * The new neuron count for this layer. */ public void changeNeuronCount(final Layer layer, final int neuronCount) { if (neuronCount == 0) { throw new NeuralNetworkError("Can't decrease to zero neurons."); } // is there anything to do? if (neuronCount == layer.getNeuronCount()) { return; } if (neuronCount > layer.getNeuronCount()) { increaseNeuronCount(layer, neuronCount); } else { decreaseNeuronCount(layer, neuronCount); } } /** * Internal function to decrease the neuron count of a layer. * * @param layer * The layer to affect. * @param neuronCount * The new neuron count. */ private void decreaseNeuronCount(final Layer layer, final int neuronCount) { // create an array to hold the least significant neurons, which will be // removed final int lostNeuronCount = layer.getNeuronCount() - neuronCount; final int[] lostNeuron = findWeakestNeurons(layer, lostNeuronCount); // finally, actually prune the neurons that the previous steps // determined to remove for (int i = 0; i < lostNeuronCount; i++) { prune(layer, lostNeuron[i] - i); } } /** * Determine the significance of the neuron. The higher the return value, * the more significant the neuron is. * * @param layer * The layer to query. * @param neuron * The neuron to query. * @return How significant is this neuron. */ public double determineNeuronSignificance(final Layer layer, final int neuron) { // calculate the bias significance double result = 0; if (layer.hasBias()) { result += layer.getBiasWeight(neuron); } // calculate the outbound significance for (final Synapse synapse : layer.getNext()) { for (int i = 0; i < synapse.getToNeuronCount(); i++) { result += synapse.getMatrix().get(neuron, i); } } // calculate the bias significance final Collection<Synapse> inboundSynapses = this.network.getStructure() .getPreviousSynapses(layer); for (final Synapse synapse : inboundSynapses) { if (synapse.getMatrix() != null) { for (int i = 0; i < synapse.getFromNeuronCount(); i++) { result += synapse.getMatrix().get(i, neuron); } } } return Math.abs(result); } /** * Find the weakest neurons on a layer. Considers both weight and bias. * @param layer The layer to search. * @param count The number of neurons to find. * @return An array of the indexes of the weakest neurons. */ private int[] findWeakestNeurons(final Layer layer, final int count) { // create an array to hold the least significant neurons, which will be // returned final double[] lostNeuronSignificance = new double[count]; final int[] lostNeuron = new int[count]; // init the potential lost neurons to the first ones, we will find // better choices if we can for (int i = 0; i < count; i++) { lostNeuron[i] = i; lostNeuronSignificance[i] = determineNeuronSignificance(layer, i); } // now loop over the remaining neurons and see if any are better ones to // remove for (int i = count; i < layer.getNeuronCount(); i++) { final double significance = determineNeuronSignificance(layer, i); // is this neuron less significant than one already chosen? for (int j = 0; j < count; j++) { if (lostNeuronSignificance[j] > significance) { lostNeuron[j] = i; lostNeuronSignificance[j] = significance; break; } } } return lostNeuron; } /** * @return The network that is being processed. */ public BasicNetwork getNetwork() { return this.network; } /** * Internal function to increase the neuron count. This will add a * zero-weight neuron to this layer. * * @param layer * The layer to increase. * @param neuronCount * The new neuron count. */ private void increaseNeuronCount(final Layer layer, final int neuronCount) { // adjust the bias final double[] newBias = new double[neuronCount]; if (layer.hasBias()) { for (int i = 0; i < layer.getNeuronCount(); i++) { newBias[i] = layer.getBiasWeight(i); } layer.setBiasWeights(newBias); } // adjust the outbound weight matrixes for (final Synapse synapse : layer.getNext()) { final Matrix newMatrix = new Matrix(neuronCount, synapse.getToNeuronCount()); // copy existing matrix to new matrix for (int row = 0; row < layer.getNeuronCount(); row++) { for (int col = 0; col < synapse.getToNeuronCount(); col++) { newMatrix.set(row, col, synapse.getMatrix().get(row, col)); } } synapse.setMatrix(newMatrix); } // adjust the inbound weight matrixes final Collection<Synapse> inboundSynapses = this.network.getStructure() .getPreviousSynapses(layer); for (final Synapse synapse : inboundSynapses) { if (synapse.getMatrix() != null) { final Matrix newMatrix = new Matrix( synapse.getFromNeuronCount(), neuronCount); // copy existing matrix to new matrix for (int row = 0; row < synapse.getFromNeuronCount(); row++) { for (int col = 0; col < synapse.getToNeuronCount(); col++) { newMatrix.set(row, col, synapse.getMatrix().get(row, col)); } } synapse.setMatrix(newMatrix); } } // adjust the bias if (layer.hasBias()) { final double[] newBias2 = new double[neuronCount]; for (int i = 0; i < layer.getNeuronCount(); i++) { newBias2[i] = layer.getBiasWeight(i); } layer.setBiasWeights(newBias2); } // finally, up the neuron count layer.setNeuronCount(neuronCount); } /** * Prune one of the neurons from this layer. Remove all entries in this * weight matrix and other layers. * * @param targetLayer * The neuron to prune. Zero specifies the first neuron. * @param neuron * The neuron to prune. */ public void prune(final Layer targetLayer, final int neuron) { // delete a row on this matrix for (final Synapse synapse : targetLayer.getNext()) { synapse.setMatrix(MatrixMath.deleteRow(synapse.getMatrix(), neuron)); } // delete a column on the previous final Collection<Layer> previous = this.network.getStructure() .getPreviousLayers(targetLayer); for (final Layer prevLayer : previous) { if (previous != null) { for (final Synapse synapse : prevLayer.getNext()) { if (synapse.getMatrix() != null) { synapse.setMatrix(MatrixMath.deleteCol( synapse.getMatrix(), neuron)); } } } } // remove the bias if (targetLayer.hasBias()) { final double[] newBias = new double[targetLayer.getNeuronCount() - 1]; int targetIndex = 0; for (int i = 0; i < targetLayer.getNeuronCount(); i++) { if (i != neuron) { newBias[targetIndex++] = targetLayer.getBiasWeight(i); } } targetLayer.setBiasWeights(newBias); } // update the neuron count targetLayer.setNeuronCount(targetLayer.getNeuronCount() - 1); } /** * Stimulate the specified neuron by the specified percent. This is used to * randomize the weights and bias values for weak neurons. * * @param percent * The percent to randomize by. * @param layer * The layer that the neuron is on. * @param neuron * The neuron to randomize. */ public void stimulateNeuron(final double percent, final Layer layer, final int neuron) { final Distort d = new Distort(percent); if (layer.hasBias()) { layer.setBiasWeight(neuron, d.randomize(layer.getBiasWeight(neuron))); } // calculate the outbound significance for (final Synapse synapse : layer.getNext()) { for (int i = 0; i < synapse.getToNeuronCount(); i++) { final double v = synapse.getMatrix().get(neuron, i); synapse.getMatrix().set(neuron, i, d.randomize(v)); } } final Collection<Synapse> inboundSynapses = this.network.getStructure() .getPreviousSynapses(layer); for (final Synapse synapse : inboundSynapses) { for (int i = 0; i < synapse.getFromNeuronCount(); i++) { final double v = synapse.getMatrix().get(i, neuron); synapse.getMatrix().set(i, neuron, d.randomize(v)); } } } /** * Stimulate weaker neurons on a layer. Find the weakest neurons and then * randomize them by the specified percent. * * @param layer * The layer to stimulate. * @param count * The number of weak neurons to stimulate. * @param percent * The percent to stimulate by. */ public void stimulateWeakNeurons(final Layer layer, final int count, final double percent) { final int[] weak = findWeakestNeurons(layer, count); for (final int element : weak) { stimulateNeuron(percent, layer, element); } } }