/**
 * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neuroph.nnet.flat;

import java.util.HashSet;
import java.util.Set;

import org.encog.engine.EncogEngine;
import org.encog.engine.EncogEngineError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.engine.network.flat.FlatLayer;
import org.encog.engine.network.flat.FlatNetwork;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.learning.TrainingData;
import org.neuroph.core.transfer.Linear;
import org.neuroph.core.transfer.Sigmoid;
import org.neuroph.core.transfer.Tanh;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.BiasNeuron;
import org.neuroph.nnet.comp.InputNeuron;
import org.neuroph.util.plugins.PluginBase;

/**
 * The FlatNetworkPlugin allows a Neuroph multi-layer perceptron to make use
 * of the Encog Engine. The Encog Engine uses flat, array-based networks to
 * train and process very efficiently.
 *
 * To make use of this plugin, call the flattenNeuralNetworkNetwork method
 * with the desired neural network. The network will then train with
 * multithreaded RPROP. The network can still be used as a normal Neuroph
 * network; however, if you change the structure of the network, you should
 * re-flatten it.
 *
 * To remove the plugin, use the unFlattenNeuralNetworkNetwork method.
 *
 * If you wish to use OpenCL/GPU processing, call the initCL method.
 *
 * Finally, once you are ready to exit your application, call the shutdown
 * method to make sure that the thread pool and OpenCL resources are properly
 * shut down.
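 *
 * A minimal usage sketch (the three-layer network dimensions and the
 * trainingSet variable are illustrative assumptions, not part of this
 * class):
 * <pre>
 * MultiLayerPerceptron mlp = new MultiLayerPerceptron(2, 3, 1);
 * if (FlatNetworkPlugin.flattenNeuralNetworkNetwork(mlp)) {
 *     mlp.learn(trainingSet); // now trains with multithreaded RPROP
 *     FlatNetworkPlugin.unFlattenNeuralNetworkNetwork(mlp);
 * }
 * FlatNetworkPlugin.shutdown(); // release the thread pool / OpenCL
 * </pre>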
 *
 * @author Jeff Heaton (http://www.jeffheaton.com)
 */
public class FlatNetworkPlugin extends PluginBase {

    /**
     * The serial ID.
     */
    private static final long serialVersionUID = 1L;

    /**
     * The name of the plugin.
     */
    public static final String PLUGIN_NAME = "FlatNetworkPlugin";

    /**
     * The flat network being used.
     */
    public FlatNetwork flatNetwork;

    /**
     * Construct a flat network plugin.
     *
     * @param network The flat network.
     */
    public FlatNetworkPlugin(FlatNetwork network) {
        super(PLUGIN_NAME);
        this.flatNetwork = network;
    }

    /**
     * @return The actual flat network being used.
     */
    public FlatNetwork getFlatNetwork() {
        return this.flatNetwork;
    }

    /**
     * Install the plugin into a neural network. Currently only
     * MultiLayerPerceptron networks can be flattened.
     *
     * @param network The neural network to flatten.
     * @return True if the flattening was successful.
     */
    public static boolean flattenNeuralNetworkNetwork(NeuralNetwork network) {
        if (network instanceof MultiLayerPerceptron) {
            return flattenMultiLayerPerceptron((MultiLayerPerceptron) network);
        } else {
            // failed to flatten the network; unsupported network type
            return false;
        }
    }

    /**
     * Init OpenCL processing.
     */
    public static void initCL() {
        EncogEngine.getInstance().initCL();
    }

    /**
     * Make sure that the thread pool and OpenCL are properly shut down.
     */
    public static void shutdown() {
        EncogEngine.getInstance().shutdown();
    }

    /**
     * Flatten a multi-layer perceptron.
     *
     * @param network The network to flatten.
     * @return True, if the network was successfully flattened.
     */
    private static boolean flattenMultiLayerPerceptron(
            MultiLayerPerceptron network) {

        // flatten each Neuroph layer into an Encog FlatLayer
        FlatLayer[] flatLayers = new FlatLayer[network.getLayers().size()];

        int index = 0;
        for (Layer layer : network.getLayers()) {
            FlatLayer flatLayer = flattenLayer(layer);
            if (flatLayer == null)
                return false;
            flatLayers[index++] = flatLayer;
        }

        // build the flat network, attach it as a plugin, and switch the
        // network over to the flat (multithreaded RPROP) learning rule
        FlatNetwork flat = new FlatNetwork(flatLayers);
        FlatNetworkPlugin plugin = new FlatNetworkPlugin(flat);
        network.addPlugin(plugin);

        FlatNetworkLearning training = new FlatNetworkLearning(flat);
        network.setLearningRule(training);

        flattenWeights(flat, network);

        return true;
    }

    /**
     * Flatten the specified layer. Only layers that use a single transfer
     * function type (Linear, Sigmoid or Tanh) and have at most one bias
     * neuron can be flattened.
     *
     * @param layer The layer to flatten.
     * @return The flat layer that represents the provided layer, or null if
     *         the layer cannot be flattened.
     */
    private static FlatLayer flattenLayer(Layer layer) {
        boolean inputLayer = false;
        Set<Class<?>> transferFunctions = new HashSet<Class<?>>();
        int neuronCount = 0;
        int biasCount = 0;
        TransferFunction transfer = null;

        for (Neuron neuron : layer.getNeurons()) {
            if (neuron.getClass() == InputNeuron.class)
                inputLayer = true;

            if (neuron.getClass() == Neuron.class
                    || neuron.getClass() == InputNeuron.class) {
                neuronCount++;
                transfer = neuron.getTransferFunction();
                transferFunctions.add(transfer.getClass());
            } else if (neuron.getClass() == BiasNeuron.class)
                biasCount++;
        }

        // a flat layer must use exactly one transfer function type
        if (transferFunctions.size() != 1)
            return null;

        Class<?> t = transferFunctions.iterator().next();
        double slope = 1;
        ActivationFunction activation = null;

        if (inputLayer)
            activation = new ActivationLinear();
        else if (t == Linear.class) {
            slope = ((Linear) transfer).getSlope();
            activation = new ActivationLinear();
        } else if (t == Sigmoid.class) {
            slope = ((Sigmoid) transfer).getSlope();
            activation = new ActivationSigmoid();
        } else if (t == Tanh.class) {
            slope = ((Tanh) transfer).getSlope();
            activation = new ActivationTANH();
        } else
            return null;

        // at most one bias neuron is supported per layer
        if (biasCount > 1)
            return null;

        double[] params = { slope };
        return new FlatLayer(activation, neuronCount,
                biasCount == 1 ? 1.0 : 0.0, params);
    }

    /**
     * Replace all of the weights with FlatWeights, so that the Neuroph
     * connections and the Encog flat network share the same weight array.
     *
     * @param flatNetwork The flat network.
     * @param network The neural network.
     */
    private static void flattenWeights(FlatNetwork flatNetwork,
            NeuralNetwork network) {
        double[] weights = flatNetwork.getWeights();
        int index = 0;

        // walk the layers from output back to input (skipping the input
        // layer, which has no input connections), matching the weight
        // ordering that the flat network expects
        for (int layerIndex = network.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
            Layer layer = network.getLayers().get(layerIndex);
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    if (index >= weights.length)
                        throw new EncogEngineError("Weight size mismatch.");

                    Weight weight = connection.getWeight();
                    FlatWeight flatWeight = new FlatWeight(weights, index++);
                    flatWeight.setValue(weight.getValue());
                    connection.setWeight(flatWeight);
                }
            }
        }
    }

    /**
     * Remove the flat network plugin, and replace flat weights with regular
     * Neuroph weights. The trained weight values are preserved.
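     *
     * A minimal sketch (assumes the network was previously flattened):
     * <pre>
     * if (FlatNetworkPlugin.unFlattenNeuralNetworkNetwork(network)) {
     *     // connections again hold plain Weight objects with trained values
     * }
     * </pre>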
     *
     * @param network The network to unflatten.
     * @return True if unflattening was successful.
     */
    public static boolean unFlattenNeuralNetworkNetwork(NeuralNetwork network) {

        for (int layerIndex = network.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
            Layer layer = network.getLayers().get(layerIndex);
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    Weight weight = connection.getWeight();
                    if (weight instanceof FlatWeight) {
                        // copy the value (and the previous value used by
                        // momentum-style rules) into a regular Weight
                        Weight weight2 = new Weight(weight.getValue());
                        weight2.getTrainingData().set(
                                TrainingData.PREVIOUS_WEIGHT,
                                weight.getTrainingData().get(
                                        TrainingData.PREVIOUS_WEIGHT));
                        connection.setWeight(weight2);
                    }
                }
            }
        }

        network.removePlugin(FlatNetworkPlugin.PLUGIN_NAME);

        return true;
    }
}