package func.nn;
import java.util.ArrayList;
import java.util.List;
import util.linalg.Vector;
/**
* A layered neural network
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public abstract class LayeredNetwork extends NeuralNetwork {
/**
* The input layer
*/
private Layer inputLayer;
/**
* The output layer
*/
private Layer outputLayer;
/**
* The list of middle layers
*/
private List hiddenLayers;
/**
* The cached list of links
*/
private List links = null;
/**
* Make a new layered network
*/
public LayeredNetwork() {
hiddenLayers = new ArrayList();
}
/**
* @see Network#getOutputValues()
*/
public Vector getOutputValues() {
return outputLayer.getActivations();
}
/**
* @see Network#setInputValues(double[])
*/
public void setInputValues(Vector values) {
inputLayer.setActivations(values);
}
/**
* Get the index of the node with the largest value
* @return the index
*/
public int getDiscreteOutputValue() {
return outputLayer.getGreatestActivationIndex();
}
/**
* Get the binary output value
* @return the binary output value
*/
public boolean getBinaryOutputValue() {
return outputLayer.getNode(0).getActivation() > .5;
}
/**
* Get the input layer
* @return the layer
*/
public Layer getInputLayer() {
return inputLayer;
}
/**
* Get the list of middle layers
* @return the list
*/
public List getHiddenLayers() {
return hiddenLayers;
}
/**
* Get the output layer
* @return the layer
*/
public Layer getOutputLayer() {
return outputLayer;
}
/**
* Set the input layer
* @param layer the new layer
*/
public void setInputLayer(Layer layer) {
inputLayer = layer;
}
/**
* Set the output layer
* @param layer the output layer
*/
public void setOutputLayer(Layer layer) {
outputLayer = layer;
}
/**
* Get the middle layer count
* @return the middle layer count
*/
public int getHiddenLayerCount() {
return hiddenLayers.size();
}
/**
* Get the middle layer
* @param i the index of the middle layer
* @return the layer
*/
public Layer getHiddenLayer(int i) {
return (Layer) hiddenLayers.get(i);
}
/**
* Add a middle layer
* @param layer the layer to add
*/
public void addHiddenLayer(Layer layer) {
hiddenLayers.add(hiddenLayers.size(), layer);
}
/**
* Disconnect this network
*/
public void disconnect() {
if (inputLayer != null && getHiddenLayerCount() > 0) {
Layer firstMiddle = getHiddenLayer(0);
inputLayer.disconnect(firstMiddle);
} else if (inputLayer != null && outputLayer != null) {
inputLayer.disconnect(outputLayer);
}
for (int i = 0; i + 1 < getHiddenLayerCount(); i++) {
Layer first = getHiddenLayer(i);
Layer second = getHiddenLayer(i + 1);
first.disconnect(second);
}
if (outputLayer != null && getHiddenLayerCount() > 0) {
Layer lastMiddle = getHiddenLayer(getHiddenLayerCount() - 1);
lastMiddle.disconnect(outputLayer);
}
}
/**
* Connect this network
*/
public void connect() {
if (inputLayer != null && getHiddenLayerCount() > 0) {
Layer firstMiddle = getHiddenLayer(0);
inputLayer.connect(firstMiddle);
} else if (inputLayer != null && outputLayer != null) {
inputLayer.connect(outputLayer);
}
for (int i = 0; i + 1 < getHiddenLayerCount(); i++) {
Layer first = getHiddenLayer(i);
Layer second = getHiddenLayer(i + 1);
first.connect(second);
}
if (outputLayer != null && getHiddenLayerCount() > 0) {
Layer lastMiddle = getHiddenLayer(getHiddenLayerCount() - 1);
lastMiddle.connect(outputLayer);
}
}
/**
* @see nn.NeuralNetwork#getLinks()
*/
public List getLinks() {
if (links != null) {
return links;
}
links = new ArrayList();
links.addAll(inputLayer.getLinks());
for (int i = 0; i < getHiddenLayerCount(); i++) {
links.addAll(getHiddenLayer(i).getLinks());
}
links.addAll(outputLayer.getLinks());
return links;
}
}