//----------------------------------------------------------------------------//
// //
// N e u r a l N e t w o r k //
// //
//----------------------------------------------------------------------------//
// <editor-fold defaultstate="collapsed" desc="hdr"> //
// Copyright © Hervé Bitteur and others 2000-2013. All rights reserved. //
// This software is released under the GNU General Public License. //
// Go to http://kenai.com/projects/audiveris to report bugs or suggestions. //
//----------------------------------------------------------------------------//
// </editor-fold>
package omr.math;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlAccessType;
import javax.xml.bind.annotation.XmlAccessorType;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlElementWrapper;
import javax.xml.bind.annotation.XmlRootElement;
/**
* Class {@code NeuralNetwork} implements a back-propagation neural
* network, with one input layer, one hidden layer and one output layer.
* The transfer function is the sigmoid.
*
* <p>This neuralNetwork class can be stored on disk in XML form (through
* the {@link #marshal} and {@link #unmarshal} methods).
*
* <p>The class also allows in-memory {@link #backup} and {@link #restore}
* operation, mainly used to save the most performant weight values during the
* network training.
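 *
 * <p>A minimal life-cycle sketch (layer sizes, file name and training data
 * below are illustrative only):
 * <pre>{@code
 * NeuralNetwork nn = new NeuralNetwork(64, 30, 10, 0.5, null, null);
 * nn.train(inputs, desiredOutputs, null);
 * try (OutputStream os = new FileOutputStream("network.xml")) {
 *     nn.marshal(os);
 * }
 * }</pre>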
*
* @author Hervé Bitteur
*/
@XmlAccessorType(XmlAccessType.NONE)
@XmlRootElement(name = "neural-network")
public class NeuralNetwork
{
//~ Static fields/initializers ---------------------------------------------
/** Usual logger utility */
private static final Logger logger = LoggerFactory.getLogger(
NeuralNetwork.class);
/** Un/marshalling context for use with JAXB */
private static volatile JAXBContext jaxbContext;
//~ Instance fields --------------------------------------------------------
//
/** Size of input layer. */
@XmlAttribute(name = "input-size")
private final int inputSize;
/** Size of hidden layer. */
@XmlAttribute(name = "hidden-size")
private final int hiddenSize;
/** Size of output layer. */
@XmlAttribute(name = "output-size")
private final int outputSize;
/** Labels of input cells. */
@XmlElementWrapper(name = "input-labels")
@XmlElement(name = "input")
private final String[] inputLabels;
/** Labels of output cells. */
@XmlElementWrapper(name = "output-labels")
@XmlElement(name = "output")
private final String[] outputLabels;
/** Weights to hidden layer. */
@XmlElementWrapper(name = "hidden-weights")
@XmlElement(name = "row")
private double[][] hiddenWeights;
/** Weights to output layer. */
@XmlElementWrapper(name = "output-weights")
@XmlElement(name = "row")
private double[][] outputWeights;
/** Flag to stop training. */
private transient volatile boolean stopping = false;
/** Learning Rate parameter. */
private transient volatile double learningRate = 0.40;
/** Max Error parameter. */
private transient volatile double maxError = 1E-4;
/** Momentum for faster convergence. */
private transient volatile double momentum = 0.25;
/** Number of epochs when training. */
private transient volatile int epochs = 1000;
//~ Constructors -----------------------------------------------------------
//---------------//
// NeuralNetwork //
//---------------//
/**
 * Create a neural network, with specified number of cells in each
 * layer, and default values.
 *
 * @param inputSize    number of cells in input layer
 * @param hiddenSize   number of cells in hidden layer
 * @param outputSize   number of cells in output layer
 * @param amplitude    amplitude (<= 1.0) for initial random values
 * @param inputLabels  array of labels for input cells, or null
 * @param outputLabels array of labels for output cells, or null
 */
public NeuralNetwork (int inputSize,
int hiddenSize,
int outputSize,
double amplitude,
String[] inputLabels,
String[] outputLabels)
{
// Cache parameters
this.inputSize = inputSize;
this.hiddenSize = hiddenSize;
this.outputSize = outputSize;
// Allocate weights (from input) to hidden layer
// +1 for bias
hiddenWeights = createMatrix(hiddenSize, inputSize + 1, amplitude);
// Allocate weights (from hidden) to output layer
// +1 for bias
outputWeights = createMatrix(outputSize, hiddenSize + 1, amplitude);
// Labels for input, if any
this.inputLabels = inputLabels;
if ((inputLabels != null) && (inputLabels.length != inputSize)) {
    throw new IllegalArgumentException(
        "Inconsistent input labels " + Arrays.toString(inputLabels)
        + " vs " + inputSize);
}
// Labels for output, if any
this.outputLabels = outputLabels;
if ((outputLabels != null) && (outputLabels.length != outputSize)) {
    throw new IllegalArgumentException(
        "Inconsistent output labels " + Arrays.toString(outputLabels)
        + " vs " + outputSize);
}
logger.debug("Network created");
}
//---------------//
// NeuralNetwork //
//---------------//
/**
 * Create a neural network, with specified number of cells in each
 * layer, and specific parameters.
 *
 * @param inputSize    number of cells in input layer
 * @param hiddenSize   number of cells in hidden layer
 * @param outputSize   number of cells in output layer
 * @param amplitude    amplitude (<= 1.0) for initial random values
 * @param inputLabels  array of labels for input cells, or null
 * @param outputLabels array of labels for output cells, or null
 * @param learningRate learning rate factor
 * @param momentum     momentum from last adjustment
 * @param maxError     threshold to stop training
 * @param epochs       number of epochs in training
 */
public NeuralNetwork (int inputSize,
int hiddenSize,
int outputSize,
double amplitude,
String[] inputLabels,
String[] outputLabels,
double learningRate,
double momentum,
double maxError,
int epochs)
{
this(
inputSize,
hiddenSize,
outputSize,
amplitude,
inputLabels,
outputLabels);
// Cache parameters
this.learningRate = learningRate;
this.momentum = momentum;
this.maxError = maxError;
this.epochs = epochs;
}
//---------------//
// NeuralNetwork //
//---------------//
/** Private no-arg constructor meant for JAXB unmarshalling only */
private NeuralNetwork ()
{
inputSize = -1;
hiddenSize = -1;
outputSize = -1;
inputLabels = null;
outputLabels = null;
}
//~ Methods ----------------------------------------------------------------
//-----------//
// unmarshal //
//-----------//
/**
* Unmarshal the provided XML stream to allocate the corresponding
* NeuralNetwork.
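 *
 * <p>For example (the file name is illustrative only):
 * <pre>{@code
 * try (InputStream in = new FileInputStream("network.xml")) {
 *     NeuralNetwork nn = NeuralNetwork.unmarshal(in);
 * }
 * }</pre>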
*
* @param in the input stream that contains the network definition in XML
* format. The stream is not closed by this method
*
* @return the allocated network.
* @exception JAXBException raised when unmarshalling goes wrong
*/
public static NeuralNetwork unmarshal (InputStream in)
throws JAXBException
{
Unmarshaller um = getJaxbContext()
.createUnmarshaller();
NeuralNetwork nn = (NeuralNetwork) um.unmarshal(in);
logger.debug("Network unmarshalled");
return nn;
}
//
//--------//
// backup //
//--------//
/**
* Return a backup of the internal memory of this network.
* Generally used right after network creation to save the initial
* conditions.
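 *
 * <p>A typical sketch of the intended pattern (the {@code nn} instance
 * is illustrative):
 * <pre>{@code
 * Backup best = nn.backup();  // snapshot the current weights
 * // ... further training ...
 * nn.restore(best);           // roll back to the snapshot
 * }</pre>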
*
* @return an opaque copy of the network memory
*/
public Backup backup ()
{
logger.debug("Network memory backup");
return new Backup(hiddenWeights, outputWeights);
}
//------//
// dump //
//------//
/**
 * Dump the network content to the logger.
 */
public void dump ()
{
StringBuilder sb = new StringBuilder();
sb.append(String.format("Network%n"));
sb.append(String.format("LearningRate = %f%n", learningRate));
sb.append(String.format("Momentum = %f%n", momentum));
sb.append(String.format("MaxError = %f%n", maxError));
sb.append(String.format("Epochs = %d%n", epochs));
// Input
sb.append(String.format("%nInputs : %d cells%n", inputSize));
// Hidden
sb.append(dumpOfMatrix(hiddenWeights));
sb.append(String.format("%nHidden : %d cells%n", hiddenSize));
// Output
sb.append(dumpOfMatrix(outputWeights));
sb.append(String.format("%nOutputs : %d cells%n", outputSize));
logger.info(sb.toString());
}
//---------------//
// getHiddenSize //
//---------------//
/**
* Report the number of cells in the hidden layer
*
* @return the size of the hidden layer
*/
public int getHiddenSize ()
{
return hiddenSize;
}
//----------------//
// getInputLabels //
//----------------//
/**
* Report the input labels, if any.
*
* @return the inputLabels, perhaps null
*/
public String[] getInputLabels ()
{
return inputLabels;
}
//--------------//
// getInputSize //
//--------------//
/**
* Report the number of cells in the input layer
*
* @return the size of input layer
*/
public int getInputSize ()
{
return inputSize;
}
//-----------------//
// getOutputLabels //
//-----------------//
/**
* Report the output labels, if any.
*
* @return the outputLabels, perhaps null
*/
public String[] getOutputLabels ()
{
return outputLabels;
}
//---------------//
// getOutputSize //
//---------------//
/**
* Report the size of the output layer
*
* @return the number of cells in the output layer
*/
public int getOutputSize ()
{
return outputSize;
}
//---------//
// marshal //
//---------//
/**
 * Marshal this NeuralNetwork to an XML stream.
 *
 * @param os the XML output stream, which is not closed by this method
 * @exception JAXBException raised when marshalling goes wrong
 */
public void marshal (OutputStream os)
throws JAXBException
{
Marshaller m = getJaxbContext()
.createMarshaller();
m.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true);
m.marshal(this, os);
logger.debug("Network marshalled");
}
//---------//
// restore //
//---------//
/**
* Restore the internal memory of a Network, from a previous Backup.
* This does not reset the current parameters such as learning rate,
* momentum, maxError or epochs.
*
* @param backup a backup previously made
*/
public void restore (Backup backup)
{
// Check parameter
if (backup == null) {
throw new IllegalArgumentException("Backup is null");
}
// Make sure backup is compatible with this neural network
if ((backup.hiddenWeights.length != hiddenSize)
|| (backup.hiddenWeights[0].length != (inputSize + 1))
|| (backup.outputWeights.length != outputSize)
|| (backup.outputWeights[0].length != (hiddenSize + 1))) {
throw new IllegalArgumentException("Incompatible backup");
}
logger.debug("Network memory restore");
this.hiddenWeights = cloneMatrix(backup.hiddenWeights);
this.outputWeights = cloneMatrix(backup.outputWeights);
}
//-----//
// run //
//-----//
/**
* Run the neural network on an array of input values, and return the
* computed output values.
* This method writes into the hiddens buffer.
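 *
 * <p>For example (a hypothetical {@code nn} network, with buffers
 * allocated on the fly):
 * <pre>{@code
 * double[] outputs = nn.run(inputs, null, null);
 * }</pre>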
*
* @param inputs the provided input values
* @param hiddens provided buffer for hidden values, or null
* @param outputs preallocated array for the computed output values, or null
* if not already allocated
*
* @return the computed output values
*/
public double[] run (double[] inputs,
double[] hiddens,
double[] outputs)
{
// Check size consistencies.
if (inputs == null) {
logger.error("run method. inputs array is null");
} else if (inputs.length != inputSize) {
logger.error(
"run method. input size {} not consistent with"
+ " network input layer {}",
inputs.length,
inputSize);
}
// Allocate the hiddens if not provided
if (hiddens == null) {
hiddens = new double[hiddenSize];
}
// Compute the hidden values
forward(inputs, hiddenWeights, hiddens);
// Allocate the outputs if not done yet
if (outputs == null) {
outputs = new double[outputSize];
} else if (outputs.length != outputSize) {
logger.error(
"run method. output size {} not consistent with"
+ " network output layer {}",
outputs.length,
outputSize);
}
// Then, compute the output values
forward(hiddens, outputWeights, outputs);
return outputs;
}
//-----------//
// setEpochs //
//-----------//
/**
* Set the number of iterations for training the network with a
* given input.
*
* @param epochs number of iterations
*/
public void setEpochs (int epochs)
{
this.epochs = epochs;
}
//-----------------//
// setLearningRate //
//-----------------//
/**
* Set the learning rate.
*
* @param learningRate the learning rate to use for each iteration
* (typically in the 0.0 .. 1.0 range)
*/
public void setLearningRate (double learningRate)
{
this.learningRate = learningRate;
}
//-------------//
// setMaxError //
//-------------//
/**
* Set the maximum error level.
*
* @param maxError maximum error
*/
public void setMaxError (double maxError)
{
this.maxError = maxError;
}
//-------------//
// setMomentum //
//-------------//
/**
* Set the momentum value.
*
* @param momentum the fraction of previous move to be reported on the next
* correction
*/
public void setMomentum (double momentum)
{
this.momentum = momentum;
}
//------//
// stop //
//------//
/**
* A means to externally stop the current training.
*/
public void stop ()
{
stopping = true;
logger.debug("Network training being stopped ...");
}
//-------//
// train //
//-------//
/**
* Train the neural network on a collection of input patterns,
* so that it delivers the expected outputs within maxError.
 * This method is not optimized for absolute speed, but rather for being
 * able to keep the best weight values.
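 *
 * <p>A minimal sketch with an inline monitor (the {@code nn} instance and
 * training data are assumed already prepared):
 * <pre>{@code
 * double mse = nn.train(inputs, desiredOutputs, new Monitor() {
 *     public void trainingStarted (int epochIndex, double mse) {}
 *     public void epochEnded (int epochIndex, double mse) {
 *         System.out.println("Epoch " + epochIndex + " mse: " + mse);
 *     }
 * });
 * }</pre>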
*
* @param inputs the provided patterns of values for input cells
* @param desiredOutputs the corresponding desired values for output cells
* @param monitor a monitor interface to be kept informed (or null)
*
* @return mse, the final mean square error
*/
public double train (double[][] inputs,
double[][] desiredOutputs,
Monitor monitor)
{
logger.debug("Network being trained");
stopping = false;
long startTime = System.currentTimeMillis();
// Check size consistencies.
if (inputs == null) {
throw new IllegalArgumentException("inputs array is null");
}
final int patternNb = inputs.length;
if (desiredOutputs == null) {
throw new IllegalArgumentException("desiredOutputs array is null");
}
// Allocate needed arrays
double[] gottenOutputs = new double[outputSize];
double[] hiddenGrads = new double[hiddenSize];
double[] outputGrads = new double[outputSize];
double[][] hiddenDeltas = createMatrix(hiddenSize, inputSize + 1, 0);
double[][] outputDeltas = createMatrix(outputSize, hiddenSize + 1, 0);
double[] hiddens = new double[hiddenSize];
// Mean Square Error
double mse = 0;
// Notify Monitor we are starting
if (monitor != null) {
// Compute the initial mse
for (int ip = 0; ip < patternNb; ip++) {
run(inputs[ip], hiddens, gottenOutputs);
for (int o = outputSize - 1; o >= 0; o--) {
double out = gottenOutputs[o];
double dif = desiredOutputs[ip][o] - out;
mse += (dif * dif);
}
}
mse /= patternNb;
mse = Math.sqrt(mse);
monitor.trainingStarted(0, mse);
}
int ie = 0;
for (; ie < epochs; ie++) {
// Have we been told to stop ?
if (stopping) {
logger.debug("Network stopped.");
break;
}
// Compute the output layer error terms
mse = 0;
// Loop on all input patterns
for (int ip = 0; ip < patternNb; ip++) {
// Run the network with input values and current weights
run(inputs[ip], hiddens, gottenOutputs);
for (int o = outputSize - 1; o >= 0; o--) {
double out = gottenOutputs[o];
double dif = desiredOutputs[ip][o] - out;
mse += (dif * dif);
outputGrads[o] = dif * out * (1 - out);
}
// Compute the hidden layer error terms
for (int h = hiddenSize - 1; h >= 0; h--) {
double sum = 0;
double hid = hiddens[h];
for (int o = outputSize - 1; o >= 0; o--) {
sum += (outputGrads[o] * outputWeights[o][h + 1]);
}
hiddenGrads[h] = sum * hid * (1 - hid);
}
// Now update the output weights
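// Delta rule with momentum:
//   dw = learningRate * gradient * input + momentum * previousDelta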
for (int o = outputSize - 1; o >= 0; o--) {
for (int h = hiddenSize - 1; h >= 0; h--) {
double dw = (learningRate * outputGrads[o] * hiddens[h])
+ (momentum * outputDeltas[o][h + 1]);
outputWeights[o][h + 1] += dw;
outputDeltas[o][h + 1] = dw;
}
// Bias
double dw = (learningRate * outputGrads[o])
+ (momentum * outputDeltas[o][0]);
outputWeights[o][0] += dw;
outputDeltas[o][0] = dw;
}
// And the hidden weights
for (int h = hiddenSize - 1; h >= 0; h--) {
for (int i = inputSize - 1; i >= 0; i--) {
double dw = (learningRate * hiddenGrads[h] * inputs[ip][i])
+ (momentum * hiddenDeltas[h][i + 1]);
hiddenWeights[h][i + 1] += dw;
hiddenDeltas[h][i + 1] = dw;
}
// Bias
double dw = (learningRate * hiddenGrads[h])
+ (momentum * hiddenDeltas[h][0]);
hiddenWeights[h][0] += dw;
hiddenDeltas[h][0] = dw;
}
} // for (int ip = 0; ip < patternNb; ip++)
// Compute the true current error (square root of mean squared error)
mse = 0d;
for (int ip = 0; ip < patternNb; ip++) {
run(inputs[ip], hiddens, gottenOutputs);
for (int o = outputSize - 1; o >= 0; o--) {
double out = gottenOutputs[o];
double dif = desiredOutputs[ip][o] - out;
mse += (dif * dif);
}
}
mse /= patternNb;
mse = Math.sqrt(mse);
if (monitor != null) {
monitor.epochEnded(ie, mse);
}
if (mse <= maxError) {
logger.info(
"Network exiting training, remaining error limit reached");
logger.info("Network remaining error was : {}", mse);
break;
}
} // for (int ie = 0; ie < epochs; ie++)
if (logger.isDebugEnabled()) {
long stopTime = System.currentTimeMillis();
logger.debug(
String.format(
"Duration %,d seconds, %d epochs on %d patterns",
(stopTime - startTime) / 1000,
ie,
patternNb));
}
return mse;
}
//-------------//
// cloneMatrix //
//-------------//
/**
* Create a clone of the provided matrix.
*
* @param matrix the matrix to clone
* @return the clone
*/
private static double[][] cloneMatrix (double[][] matrix)
{
final int rowNb = matrix.length;
final int colNb = matrix[0].length;
double[][] clone = new double[rowNb][];
for (int row = rowNb - 1; row >= 0; row--) {
clone[row] = new double[colNb];
System.arraycopy(matrix[row], 0, clone[row], 0, colNb);
}
return clone;
}
//--------------//
// createMatrix //
//--------------//
/**
 * Create and initialize a matrix, with random values uniformly
 * drawn between -amplitude and +amplitude.
 *
 * @param rowNb     number of rows
 * @param colNb     number of columns
 * @param amplitude maximum absolute value of the random values
 *
 * @return the properly initialized matrix
 */
private static double[][] createMatrix (int rowNb,
int colNb,
double amplitude)
{
double[][] matrix = new double[rowNb][];
for (int row = rowNb - 1; row >= 0; row--) {
double[] vector = new double[colNb];
matrix[row] = vector;
for (int col = colNb - 1; col >= 0; col--) {
vector[col] = amplitude * (1.0 - (2 * Math.random()));
}
}
return matrix;
}
//--------------//
// dumpOfMatrix //
//--------------//
/**
* Dump a matrix (assumed to be a true rectangular matrix,
* with all rows of the same length).
*
* @param matrix the matrix to dump
* @return the matrix representation
*/
private String dumpOfMatrix (double[][] matrix)
{
StringBuilder sb = new StringBuilder();
for (int col = 0; col < matrix[0].length; col++) {
sb.append(String.format("%14d", col));
}
sb.append(String.format("%n"));
for (int row = 0; row < matrix.length; row++) {
sb.append(String.format("%2d:", row));
for (int col = 0; col < matrix[0].length; col++) {
sb.append(String.format("%14e", matrix[row][col]));
}
sb.append(String.format("%n"));
}
return sb.toString();
}
//---------//
// forward //
//---------//
/**
 * Forward propagation from one layer to the next: each output cell gets
 * the sigmoid of the weighted sum of the input cells, plus a bias.
 * This method is re-entrant.
*
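 * <pre>outs[o] = sigmoid( weights[o][0] + sum_i( weights[o][i + 1] * ins[i] ) )</pre>
 *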
* @param ins input cells
* @param weights applied weights
* @param outs output cells
*/
private void forward (double[] ins,
double[][] weights,
double[] outs)
{
double sum;
double[] ws;
for (int o = outs.length - 1; o >= 0; o--) {
sum = 0;
ws = weights[o];
for (int i = ins.length - 1; i >= 0; i--) {
sum += (ws[i + 1] * ins[i]);
}
// Bias
sum += ws[0];
outs[o] = sigmoid(sum);
}
}
//----------------//
// getJaxbContext //
//----------------//
private static JAXBContext getJaxbContext ()
throws JAXBException
{
// Lazy creation
if (jaxbContext == null) {
jaxbContext = JAXBContext.newInstance(NeuralNetwork.class);
}
return jaxbContext;
}
//---------//
// sigmoid //
//---------//
/**
 * Simple sigmoid function, sigmoid(val) = 1 / (1 + exp(-val)), whose
 * transition step is centered on the 0 abscissa.
 *
 * @param val abscissa
 * @return the related function value
 */
private double sigmoid (double val)
{
return 1.0d / (1.0d + Math.exp(-val));
}
//~ Inner Interfaces -------------------------------------------------------
//
//---------//
// Monitor //
//---------//
/**
 * Interface {@code Monitor} makes it possible to plug a monitor into
 * a NeuralNetwork instance, to be kept informed about the progress of
 * the training activity.
 */
public interface Monitor
{
//~ Methods ------------------------------------------------------------
/**
* Entry called at end of each epoch during the training phase.
*
* @param epochIndex the sequential index of completed epoch
* @param mse the remaining mean square error
*/
void epochEnded (int epochIndex,
double mse);
/**
 * Entry called at the beginning of the training phase, to allow
 * initial snapshots for example.
 *
 * @param epochIndex the sequential index (0)
 * @param mse        the starting mean square error
 */
void trainingStarted (final int epochIndex,
final double mse);
}
//~ Inner Classes ----------------------------------------------------------
//
//--------//
// Backup //
//--------//
/**
 * Class {@code Backup} is an opaque class that encapsulates a
 * snapshot of a NeuralNetwork internal memory (its weights).
 * A Backup instance can only be obtained through the {@link #backup}
 * method of a NeuralNetwork, and is the needed parameter for a
 * {@link #restore} action.
 */
public static class Backup
{
//~ Instance fields ----------------------------------------------------
private double[][] hiddenWeights;
private double[][] outputWeights;
//~ Constructors -------------------------------------------------------
// Private constructor
private Backup (double[][] hiddenWeights,
double[][] outputWeights)
{
this.hiddenWeights = cloneMatrix(hiddenWeights);
this.outputWeights = cloneMatrix(outputWeights);
}
}
}