/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.structure;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.flat.FlatLayer;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.flat.FlatNetworkRBF;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ObjectPair;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.ContextLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.layers.RadialBasisFunctionLayer;
import org.encog.neural.networks.logic.FeedforwardLogic;
import org.encog.neural.networks.logic.SimpleRecurrentLogic;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.util.obj.ReflectionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Holds "cached" information about the structure of the neural network. This is
* a very good performance boost since the neural network does not need to
* traverse itself each time a complete collection of layers or synapses is
* needed.
*
* @author jheaton
*
*/
public class NeuralStructure implements Serializable {
/**
* The serial ID used for Java serialization of this class.
*/
private static final long serialVersionUID = -2929683885395737817L;
/**
* The logging object, shared by all instances of this class.
*/
private static final transient Logger LOGGER = LoggerFactory
.getLogger(NeuralStructure.class);
/**
* The cached list of layers in this neural network. It is rebuilt by
* finalizeLayers() and ordered by sort().
*/
private final List<Layer> layers = new ArrayList<Layer>();
/**
* The cached list of synapses in this neural network. It is rebuilt by
* finalizeSynapses() and ordered by sort().
*/
private final List<Synapse> synapses = new ArrayList<Synapse>();
/**
* The neural network this class belongs to.
*/
private final BasicNetwork network;
/**
* The limit, below which a connection is treated as zero. enforceLimit()
* sets every weight whose absolute value is below this limit to zero.
*/
private double connectionLimit;
/**
* Are connections limited? Set by finalizeLimit() from the network's limit
* property.
*/
private boolean connectionLimited;
/**
* The next ID to be assigned to a layer. Starts at 1 and is advanced by
* getNextID().
*/
private int nextID = 1;
/**
* The flattened form of the network. Transient, so it is not serialized.
* It is null until flatten() succeeds in building a flat network.
*/
private transient FlatNetwork flat;
/**
* What type of update is needed to the flat network. Transient, so it is
* not serialized; updateFlatNetwork() handles the resulting null value.
*/
private transient FlatUpdateNeeded flatUpdate;
/**
* Construct a structure object for the specified network. The layer and
* synapse lists start out empty; finalizeStructure() populates them.
*
* @param network
* The network to construct a structure for.
*/
public NeuralStructure(final BasicNetwork network) {
this.network = network;
// no flat network exists yet, so no update is pending
this.flatUpdate = FlatUpdateNeeded.None;
}
/**
 * Walk the cached layer list and give an ID to each layer that still lacks
 * one, then re-sort the layers and synapses.
 */
public void assignID() {
    final int layerCount = this.layers.size();
    for (int position = 0; position < layerCount; position++) {
        assignID(this.layers.get(position));
    }
    sort();
}
/**
 * Give the specified layer the next free ID, unless it already has one. A
 * layer whose ID is -1 is considered to have no ID yet.
 *
 * @param layer
 *            The layer to get an ID assigned.
 */
public void assignID(final Layer layer) {
    final boolean alreadyAssigned = layer.getID() != -1;
    if (alreadyAssigned) {
        return;
    }
    layer.setID(getNextID());
}
/**
* Calculate the size that an array should be to hold all of the weights and
* bias values. The calculation is delegated to NetworkCODEC.networkSize()
* for the owning network.
*
* @return The size of the calculated array.
*/
public int calculateSize() {
return NetworkCODEC.networkSize(this.network);
}
/**
 * Check whether any cached layer is an instance of the given type, as
 * decided by ReflectionUtil.isInstanceOf().
 *
 * @param type
 *            The layer type we are looking for.
 * @return True if this layer type is present.
 */
public boolean containsLayerType(final Class<?> type) {
    boolean found = false;
    // stop scanning as soon as the first matching layer is seen
    for (int position = 0; !found && (position < this.layers.size()); position++) {
        found = ReflectionUtil.isInstanceOf(
                this.layers.get(position).getClass(), type);
    }
    return found;
}
/**
 * Count the number of non-context layers.
 *
 * A layer is treated as a context layer when it is an instance of
 * ContextLayer, including subclasses. This is the same test flatten() uses
 * to decide which layers get a flat layer, so the count always matches the
 * number of flat layers that flatten() creates.
 *
 * @return The number of non-context layers.
 */
private int countNonContext() {
    int result = 0;
    for (final Layer layer : this.getLayers()) {
        // instanceof (not class equality) keeps this in step with flatten()
        if (!(layer instanceof ContextLayer)) {
            result++;
        }
    }
    return result;
}
/**
 * Sever every connection that falls below the connection limit. When the
 * network is connection limited, each weight whose absolute value is below
 * the limit is set to zero. Synapses without a matrix are skipped. Does
 * nothing when the network is not connection limited.
 */
public void enforceLimit() {
    if (this.connectionLimited) {
        for (final Synapse synapse : this.synapses) {
            final Matrix weights = synapse.getMatrix();
            if (weights == null) {
                continue;
            }
            for (int r = 0; r < weights.getRows(); r++) {
                for (int c = 0; c < weights.getCols(); c++) {
                    if (Math.abs(weights.get(r, c)) < this.connectionLimit) {
                        weights.set(r, c, 0);
                    }
                }
            }
        }
    }
}
/**
 * Rebuild the cached layer list by walking the network from every tagged
 * layer, then make sure the ID counter is past every ID already in use and
 * sort the result.
 */
private void finalizeLayers() {
    // feedforward and simple recurrent networks carry no bias weights on
    // the input layer
    final Class<?> logicClass = this.network.getLogic().getClass();
    final boolean stripInputBias = (logicClass == FeedforwardLogic.class)
            || (logicClass == SimpleRecurrentLogic.class);
    if (stripInputBias) {
        this.network.getLayer(BasicNetwork.TAG_INPUT).setBiasWeights(null);
    }

    // collect into a scratch list first, then swap it into the cache
    final List<Layer> discovered = new ArrayList<Layer>();
    this.layers.clear();
    for (final Layer tagged : this.network.getLayerTags().values()) {
        getLayers(discovered, tagged);
    }
    this.layers.addAll(discovered);

    // make sure that the next ID handed out cannot repeat an existing one
    for (int position = 0; position < this.layers.size(); position++) {
        final int existingId = this.layers.get(position).getID();
        if (existingId >= this.nextID) {
            this.nextID = existingId + 1;
        }
    }

    sort();
}
/**
 * Parse/finalize the limit value for connections.
 *
 * Reads the network's limit property. When the property is absent the
 * structure is marked as not connection limited and the limit is reset to
 * zero. When it is present it must parse as a double; the parsed value
 * becomes the connection limit and the structure is marked as limited.
 *
 * The property is parsed before any field is changed, so a malformed
 * value leaves the previous limit settings untouched.
 *
 * @throws NeuralNetworkError
 *             If the limit property is present but is not a valid number.
 */
private void finalizeLimit() {
    // see if there is a connection limit imposed
    final String limit = this.network
            .getPropertyString(BasicNetwork.TAG_LIMIT);

    if (limit == null) {
        this.connectionLimited = false;
        this.connectionLimit = 0;
        return;
    }

    final double parsedLimit;
    try {
        parsedLimit = Double.parseDouble(limit);
    } catch (final NumberFormatException e) {
        throw new NeuralNetworkError("Invalid property("
                + BasicNetwork.TAG_LIMIT + "):" + limit);
    }

    // only commit once the value is known to be valid
    this.connectionLimit = parsedLimit;
    this.connectionLimited = true;
}
/**
* Build the synapse and layer structure. This method should be called after
* you are done adding layers to a network, or change the network's logic
* property.
*
* The steps run in a fixed order: the layer and synapse lists are rebuilt,
* the connection limit is read, layers are sorted and given IDs, the
* network's logic is initialized, the limit is enforced and finally the
* flat network is regenerated.
*/
public void finalizeStructure() {
// rebuild the cached layer list, then the synapse list derived from it
finalizeLayers();
finalizeSynapses();
// read the connection limit property, if one is set
finalizeLimit();
// order the layers by their natural ordering before IDs are handed out
Collections.sort(this.layers);
// give every layer that lacks an ID one; this also calls sort()
assignID();
this.network.getLogic().init(this.network);
// zero out weights below the limit, then rebuild the flat network
enforceLimit();
flatten();
}
/**
 * Rebuild the cached synapse list from the outbound synapses of every
 * cached layer. A set is used while collecting so that each synapse
 * appears only once.
 */
private void finalizeSynapses() {
    final Set<Synapse> unique = new HashSet<Synapse>();
    for (final Layer source : getLayers()) {
        unique.addAll(source.getNext());
    }
    this.synapses.clear();
    this.synapses.addAll(unique);
}
/**
 * Find the bias activation of the BasicLayer that follows the given layer.
 *
 * @param layer
 *            The layer to search from.
 * @return The bias activation of the next BasicLayer if there is one and
 *         it has bias, otherwise FlatNetwork.NO_BIAS_ACTIVATION.
 */
private double findNextBias(final Layer layer) {
    // a layer with no outbound synapses has nothing to take a bias from
    if (layer.getNext().size() <= 0) {
        return FlatNetwork.NO_BIAS_ACTIVATION;
    }

    final Synapse outbound = this.network.getStructure()
            .findNextSynapseByLayerType(layer, BasicLayer.class);
    if (outbound == null) {
        return FlatNetwork.NO_BIAS_ACTIVATION;
    }

    final Layer target = outbound.getToLayer();
    if (!target.hasBias()) {
        return FlatNetwork.NO_BIAS_ACTIVATION;
    }
    return target.getBiasActivation();
}
/**
 * Find the first outbound synapse of a layer whose target layer is exactly
 * of the given class (subclasses do not match).
 *
 * @param layer
 *            The layer to search from.
 * @param type
 *            The class of the target layer to look for.
 * @return The synapse found, or null.
 */
public Synapse findNextSynapseByLayerType(final Layer layer,
        final Class<? extends Layer> type) {
    Synapse match = null;
    for (final Synapse candidate : layer.getNext()) {
        if (candidate.getToLayer().getClass() == type) {
            match = candidate;
            break;
        }
    }
    return match;
}
/**
 * Find the first inbound synapse of a layer whose source layer is exactly
 * of the given class (subclasses do not match).
 *
 * @param layer
 *            The layer to start from.
 * @param type
 *            The class of the source layer to look for.
 * @return The synapse found, or null.
 */
public Synapse findPreviousSynapseByLayerType(final Layer layer,
        final Class<? extends Layer> type) {
    final List<Synapse> inbound = getPreviousSynapses(layer);
    for (int position = 0; position < inbound.size(); position++) {
        final Synapse candidate = inbound.get(position);
        if (candidate.getFromLayer().getClass() == type) {
            return candidate;
        }
    }
    return null;
}
/**
 * Look up the synapse that runs from one layer to another. Layers are
 * compared by identity.
 *
 * @param fromLayer
 *            The from layer.
 * @param toLayer
 *            The to layer.
 * @param required
 *            Is this required? When true, a missing synapse is logged as
 *            an error and a NeuralNetworkError is thrown.
 * @return The synapse, if it exists, otherwise null.
 */
public Synapse findSynapse(final Layer fromLayer, final Layer toLayer,
        final boolean required) {
    for (final Synapse candidate : getSynapses()) {
        final boolean sameSource = candidate.getFromLayer() == fromLayer;
        if (sameSource && (candidate.getToLayer() == toLayer)) {
            return candidate;
        }
    }

    // nothing matched; that is only a problem if the caller needs it
    if (!required) {
        return null;
    }

    final StringBuilder message = new StringBuilder();
    message.append("This operation requires a network with a synapse between the ");
    message.append(nameLayer(fromLayer));
    message.append(" layer to the ");
    message.append(nameLayer(toLayer));
    message.append(" layer.");
    final String str = message.toString();

    if (NeuralStructure.LOGGER.isErrorEnabled()) {
        NeuralStructure.LOGGER.error(str);
    }
    throw new NeuralNetworkError(str);
}
/**
 * Flatten the network. Generate the flat network.
 *
 * The current flat network is discarded first. If ValidateForFlat reports
 * the network as not valid for flattening, no flat network is built and
 * the flat update state is set to Never. Otherwise:
 *
 * - A three-layer network whose second cached layer is a
 * RadialBasisFunctionLayer becomes a FlatNetworkRBF.
 * - Any other network becomes a FlatNetwork with one FlatLayer per
 * non-context layer. Context layers are represented as "context fed by"
 * links between flat layers.
 *
 * In both cases the weights are then copied into the flat network and the
 * flat update state is set to None.
 *
 * @throws NeuralNetworkError
 *             If an RBF network uses bias, or if a context layer is not
 *             wired to BasicLayers in the way the flat network requires.
 */
public void flatten() {
    final Map<Layer, FlatLayer> regular2flat = new HashMap<Layer, FlatLayer>();
    final Map<FlatLayer, Layer> flat2regular = new HashMap<FlatLayer, Layer>();
    final List<ObjectPair<Layer, Layer>> contexts = new ArrayList<ObjectPair<Layer, Layer>>();
    this.flat = null;

    final ValidateForFlat val = new ValidateForFlat();
    if (val.isValid(this.network) != null) {
        // a non-null result means the network cannot be flattened
        this.flatUpdate = FlatUpdateNeeded.Never;
        return;
    }

    // special case: input / RBF / output network
    if ((this.layers.size() == 3)
            && (this.layers.get(1) instanceof RadialBasisFunctionLayer)) {
        final RadialBasisFunctionLayer rbf = (RadialBasisFunctionLayer) this.layers
                .get(1);
        for (final Layer layer : this.layers) {
            if (layer.hasBias()) {
                throw new NeuralNetworkError("Bias cannot be used with an RBF neural network.");
            }
        }
        this.flat = new FlatNetworkRBF(this.network.getInputCount(),
                rbf.getNeuronCount(), this.network.getOutputCount(),
                rbf.getRadialBasisFunction());
        flattenWeights();
        this.flatUpdate = FlatUpdateNeeded.None;
        return;
    }

    final int flatLayerCount = countNonContext();
    final FlatLayer[] flatLayers = new FlatLayer[flatLayerCount];

    // the flat layer array is filled from its last slot backwards, so it
    // ends up in the reverse of the cached layer order
    int index = flatLayers.length - 1;
    for (final Layer layer : this.layers) {
        if (layer instanceof ContextLayer) {
            final Synapse inboundSynapse = this.network.getStructure()
                    .findPreviousSynapseByLayerType(layer,
                            BasicLayer.class);
            final Synapse outboundSynapse = this.network
                    .getStructure()
                    .findNextSynapseByLayerType(layer, BasicLayer.class);
            if (inboundSynapse == null) {
                throw new NeuralNetworkError(
                        "Context layer must be connected to by one BasicLayer.");
            }
            if (outboundSynapse == null) {
                throw new NeuralNetworkError(
                        "Context layer must connect to by one BasicLayer.");
            }
            // remember which layer feeds the context and which layer the
            // context feeds; they are linked once all flat layers exist
            final Layer inbound = inboundSynapse.getFromLayer();
            final Layer outbound = outboundSynapse.getToLayer();
            contexts
                    .add(new ObjectPair<Layer, Layer>(inbound, outbound));
        } else {
            final double bias = findNextBias(layer);
            final ActivationFunction activationType;
            final double[] params;
            if (layer.getActivationFunction() == null) {
                // a layer without an activation function is flattened as
                // a linear layer with a single parameter of 1
                activationType = new ActivationLinear();
                params = new double[1];
                params[0] = 1;
            } else {
                activationType = layer.getActivationFunction();
                params = layer.getActivationFunction().getParams();
            }
            final FlatLayer flatLayer = new FlatLayer(activationType,
                    layer.getNeuronCount(), bias, params);
            regular2flat.put(layer, flatLayer);
            flat2regular.put(flatLayer, layer);
            flatLayers[index--] = flatLayer;
        }
    }

    // now link up the context layers
    for (final ObjectPair<Layer, Layer> context : contexts) {
        final Layer layer = context.getB();
        final Synapse synapse = this.network
                .getStructure()
                .findPreviousSynapseByLayerType(layer, BasicLayer.class);
        if (synapse == null) {
            // fail with a descriptive error rather than a bare
            // NullPointerException on the dereference below
            throw new NeuralNetworkError(
                    "Layer fed by a context layer must be connected to by one BasicLayer.");
        }
        final FlatLayer from = regular2flat.get(context.getA());
        final FlatLayer to = regular2flat.get(synapse.getFromLayer());
        to.setContextFedBy(from);
    }

    this.flat = new FlatNetwork(flatLayers);

    // update the context indexes on the non-flat network
    for (int i = 0; i < flatLayerCount; i++) {
        final FlatLayer fedBy = flatLayers[i].getContextFedBy();
        if (fedBy != null) {
            // the context layer is located through the layer that follows
            // this one in the flat layer array
            final Layer fedBy2 = flat2regular.get(flatLayers[i + 1]);
            final Synapse synapse = findPreviousSynapseByLayerType(fedBy2,
                    ContextLayer.class);
            if (synapse == null) {
                throw new NeuralNetworkError(
                        "Can't find parent synapse to context layer.");
            }
            final ContextLayer context = (ContextLayer) synapse
                    .getFromLayer();

            // find fedby index
            int fedByIndex = -1;
            for (int j = 0; j < flatLayerCount; j++) {
                if (flatLayers[j] == fedBy) {
                    fedByIndex = j;
                    break;
                }
            }
            if (fedByIndex == -1) {
                throw new NeuralNetworkError(
                        "Can't find layer feeding context.");
            }
            context.setFlatContextIndex(this.flat
                    .getContextTargetOffset()[fedByIndex]);
        }
    }

    flattenWeights();
    this.flatUpdate = FlatUpdateNeeded.None;
}
/**
 * Copy the network's weights, context values and connection limit into the
 * existing flat network without rebuilding its structure. Does nothing
 * when there is no flat network.
 */
public void flattenWeights() {
    if (this.flat == null) {
        return;
    }

    this.flatUpdate = FlatUpdateNeeded.Flatten;
    final double[] flatWeights = this.flat.getWeights();
    final double[] encoded = NetworkCODEC.networkToArray(this.network);
    EngineArray.arrayCopy(encoded, flatWeights);
    this.flatUpdate = FlatUpdateNeeded.None;

    // push each context layer's values into the flat layer output
    for (final Layer candidate : this.layers) {
        if (!(candidate instanceof ContextLayer)) {
            continue;
        }
        final ContextLayer context = (ContextLayer) candidate;
        // an index of -1 means no flat position was assigned
        if (context.getFlatContextIndex() == -1) {
            continue;
        }
        EngineArray.arrayCopy(context.getContext().getData(), 0,
                this.flat.getLayerOutput(),
                context.getFlatContextIndex(),
                context.getContext().size());
    }

    // handle limited connection networks
    if (!this.connectionLimited) {
        this.flat.clearConnectionLimit();
    } else {
        this.flat.setConnectionLimit(this.connectionLimit);
    }
}
/**
* @return The connection limit. Weights whose absolute value is below this
* limit are zeroed by enforceLimit() when connections are limited.
*/
public double getConnectionLimit() {
return this.connectionLimit;
}
/**
* @return The flat network, or null if no flat network has been built.
*/
public FlatNetwork getFlat() {
return this.flat;
}
/**
* @return The type of update currently needed. May be null, since the
* field is transient and is not restored by deserialization.
*/
public FlatUpdateNeeded getFlatUpdate() {
return this.flatUpdate;
}
/**
* @return The layers in this neural network. This is the internal cached
* list itself, not a copy.
*/
public List<Layer> getLayers() {
return this.layers;
}
/**
 * Recursive helper used to build the layer structure. Adds the given layer
 * to the list if it is not there yet, then follows each outbound synapse
 * into any target layer that has not been collected.
 *
 * @param result
 *            The layer list being accumulated.
 * @param layer
 *            The current layer being processed.
 */
private void getLayers(final List<Layer> result, final Layer layer) {
    final boolean unseen = !result.contains(layer);
    if (unseen) {
        result.add(layer);
    }

    for (final Synapse outbound : layer.getNext()) {
        final Layer target = outbound.getToLayer();
        if (result.contains(target)) {
            // already collected; do not walk it again
            continue;
        }
        getLayers(result, target);
    }
}
/**
* @return The network this structure belongs to, as passed to the
* constructor.
*/
public BasicNetwork getNetwork() {
return this.network;
}
/**
* Get the next layer id. Each call returns the current counter value and
* then advances the counter by one, so this method has a side effect.
*
* @return The next layer id.
*/
public int getNextID() {
return this.nextID++;
}
/**
 * Collect every layer that has an outbound synapse leading to the
 * specified layer. Layers are compared by identity.
 *
 * @param targetLayer
 *            The target layer.
 * @return The previous layers.
 */
public Collection<Layer> getPreviousLayers(final Layer targetLayer) {
    final Collection<Layer> sources = new HashSet<Layer>();
    for (final Layer candidate : this.getLayers()) {
        for (final Synapse outbound : candidate.getNext()) {
            if (outbound.getToLayer() != targetLayer) {
                continue;
            }
            sources.add(outbound.getFromLayer());
        }
    }
    return sources;
}
/**
 * Collect the cached synapses that lead into the specified layer, in the
 * order they appear in the synapse list and without duplicates.
 *
 * @param targetLayer
 *            The layer to get the inbound synapses of.
 * @return A list of synapses.
 */
public List<Synapse> getPreviousSynapses(final Layer targetLayer) {
    final List<Synapse> inbound = new ArrayList<Synapse>();
    for (int position = 0; position < this.synapses.size(); position++) {
        final Synapse candidate = this.synapses.get(position);
        if ((candidate.getToLayer() == targetLayer)
                && !inbound.contains(candidate)) {
            inbound.add(candidate);
        }
    }
    return inbound;
}
/**
* @return All synapses in the neural network. This is the internal cached
* list itself, not a copy.
*/
public List<Synapse> getSynapses() {
return this.synapses;
}
/**
* @return True if a connection limit is in effect. The flag is set by
* finalizeLimit() when the network has a limit property.
*/
public boolean isConnectionLimited() {
return this.connectionLimited;
}
/**
 * @return True if at least one cached layer is a context layer.
 */
public boolean isRecurrent() {
    boolean hasContext = false;
    for (final Layer candidate : this.getLayers()) {
        if (candidate instanceof ContextLayer) {
            hasContext = true;
            break;
        }
    }
    return hasContext;
}
/**
 * Collect every tag under which the specified layer is registered with the
 * network. Layers are compared by identity.
 *
 * @param layer
 *            The layer to name.
 * @return The tags that refer to this layer; empty if there are none.
 */
public List<String> nameLayer(final Layer layer) {
    final List<String> tags = new ArrayList<String>();
    for (final Entry<String, Layer> tagged : this.network.getLayerTags()
            .entrySet()) {
        if (tagged.getValue() != layer) {
            continue;
        }
        tags.add(tagged.getKey());
    }
    return tags;
}
/**
* Set the type of flat update needed. updateFlatNetwork() acts on this
* value the next time it is called.
*
* @param flatUpdate
* The type of flat update needed.
*/
public void setFlatUpdate(final FlatUpdateNeeded flatUpdate) {
this.flatUpdate = flatUpdate;
}
/**
* Sort the layers and synapses. Layers are ordered with a LayerComparator
* and synapses with a SynapseComparator, each constructed with this
* structure.
*/
public void sort() {
Collections.sort(this.layers, new LayerComparator(this));
Collections.sort(this.synapses, new SynapseComparator(this));
}
/**
 * Copy the flat network's weights and context values back into the regular
 * network. Does nothing when there is no flat network.
 */
public void unflattenWeights() {
    if (this.flat == null) {
        return;
    }

    final double[] flatWeights = this.flat.getWeights();
    NetworkCODEC.arrayToNetwork(flatWeights, this.network);
    this.flatUpdate = FlatUpdateNeeded.None;

    // pull each context layer's values back out of the flat layer output
    for (final Layer candidate : this.layers) {
        if (!(candidate instanceof ContextLayer)) {
            continue;
        }
        final ContextLayer context = (ContextLayer) candidate;
        // an index of -1 means no flat position was assigned
        if (context.getFlatContextIndex() == -1) {
            continue;
        }
        EngineArray.arrayCopy(this.flat.getLayerOutput(),
                context.getFlatContextIndex(),
                context.getContext().getData(), 0,
                context.getContext().size());
    }
}
/**
 * Bring the flat network and the regular network back in sync, according
 * to the pending update type, and then clear the pending update.
 */
public void updateFlatNetwork() {
    // if flatUpdate is null, the network was likely just loaded from a
    // serialized file
    if (this.flatUpdate == null) {
        flattenWeights();
        this.flatUpdate = FlatUpdateNeeded.None;
    }

    final FlatUpdateNeeded pending = this.flatUpdate;

    // nothing to do, and the pending state is left exactly as it is
    if ((pending == FlatUpdateNeeded.None)
            || (pending == FlatUpdateNeeded.Never)) {
        return;
    }

    if (pending == FlatUpdateNeeded.Flatten) {
        flattenWeights();
    } else if (pending == FlatUpdateNeeded.Unflatten) {
        unflattenWeights();
    }

    this.flatUpdate = FlatUpdateNeeded.None;
}
}