//----------------------------------------------------------------------------// // // // G l y p h N e t w o r k // // // //----------------------------------------------------------------------------// // <editor-fold defaultstate="collapsed" desc="hdr"> // // Copyright © Hervé Bitteur and others 2000-2013. All rights reserved. // // This software is released under the GNU General Public License. // // Goto http://kenai.com/projects/audiveris to report bugs or suggestions. // //----------------------------------------------------------------------------// // </editor-fold> package omr.glyph; import omr.constant.Constant; import omr.constant.ConstantSet; import omr.glyph.facets.Glyph; import omr.math.NeuralNetwork; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.EnumMap; import java.util.List; import javax.xml.bind.JAXBException; /** * Class {@code GlyphNetwork} encapsulates a neural network customized * for glyph recognition. * It wraps the generic {@link NeuralNetwork} with application * information, for training, storing, loading and using the neural network. * * <p>The application neural network data is loaded as follows: <ol> * <li>It first tries to find a file named 'eval/neural-network.xml' in the * application user area. * If any, this file contains a custom definition of the network, typically * after a user training.</li> * * <li>If not found, it falls back reading the default definition from the * application resource, reading the 'res/neural-network.xml' file in the * application program area.</ol></p> * * <p>After a user training of the neural network, the data is stored as * the custom definition in the user local file 'eval/neural-network.xml', * which will be picked up first when the application is run again.</p> * * @author Hervé Bitteur */ public class GlyphNetwork extends AbstractEvaluationEngine { //~ Static fields/initializers --------------------------------------------- /** Specific application parameters */ private static final Constants constants = new Constants(); /** Usual logger utility */ private static final Logger logger = LoggerFactory.getLogger(GlyphNetwork.class); /** The singleton. */ private static volatile GlyphNetwork INSTANCE; /** Neural network file name. */ private static final String FILE_NAME = "neural-network.xml"; //~ Instance fields -------------------------------------------------------- // /** The underlying neural network. */ private NeuralNetwork engine; //~ Constructors ----------------------------------------------------------- // //--------------// // GlyphNetwork // //--------------// /** * Private constructor, to create a glyph neural network. */ private GlyphNetwork () { // Unmarshal from user or default data, if compatible engine = (NeuralNetwork) unmarshal(); if (engine == null) { // Get a brand new one (not trained) logger.info("Creating a brand new {}", getName()); engine = createNetwork(); } } //~ Methods ---------------------------------------------------------------- // //-------------// // getInstance // //-------------// /** * Report the single instance of GlyphNetwork in the application. * * @return the instance */ public static GlyphNetwork getInstance () { if (INSTANCE == null) { synchronized (GlyphNetwork.class) { if (INSTANCE == null) { INSTANCE = new GlyphNetwork(); } } } return INSTANCE; } //--------------// // isCompatible // //--------------// @Override protected final boolean isCompatible (Object obj) { if (obj instanceof NeuralNetwork) { NeuralNetwork anEngine = (NeuralNetwork) obj; if (!Arrays.equals(anEngine.getInputLabels(), ShapeDescription.getParameterLabels())) { if (logger.isDebugEnabled()) { logger.debug("Engine inputs: {}", Arrays.toString(anEngine.getInputLabels())); logger.debug("Shape inputs: {}", Arrays.toString(ShapeDescription.getParameterLabels())); } return false; } if (!Arrays.equals(anEngine.getOutputLabels(), ShapeSet.getPhysicalShapeNames())) { if (logger.isDebugEnabled()) { logger.debug("Engine outputs: {}", Arrays.toString(anEngine.getOutputLabels())); logger.debug("Physical shapes: {}", Arrays.toString(ShapeSet.getPhysicalShapeNames())); } return false; } return true; } else { return false; } } //------// // dump // //------// /** * Dump the internals of the neural network to the standard output. */ @Override public void dump () { engine.dump(); } //--------------// // getAmplitude // //--------------// /** * Selector for the amplitude value (used in initial random values). * * @return the amplitude value */ public double getAmplitude () { return constants.amplitude.getValue(); } //-----------------// // getLearningRate // //-----------------// /** * Selector of the current value for network learning rate. * * @return the current learning rate */ public double getLearningRate () { return constants.learningRate.getValue(); } //---------------// // getListEpochs // //---------------// /** * Selector on the maximum numner of training iterations. * * @return the upper limit on iteration counter */ public int getListEpochs () { return constants.listEpochs.getValue(); } //-------------// // getMaxError // //-------------// /** * Report the error threshold to potentially stop the training * process. * * @return the threshold currently in use */ public double getMaxError () { return constants.maxError.getValue(); } //-------------// // getMomentum // //-------------// /** * Report the momentum training value currently in use. * * @return the momentum in use */ public double getMomentum () { return constants.momentum.getValue(); } //---------// // getName // //---------// /** * Report a name for this network. * * @return a simple name */ @Override public final String getName () { return "Neural Network"; } //------------// // getNetwork // //------------// /** * Selector to the encapsulated Neural Network. * * @return the neural network */ public NeuralNetwork getNetwork () { return engine; } //--------------// // setAmplitude // //--------------// /** * Set the amplitude value for initial random values (UNUSED). * * @param amplitude */ public void setAmplitude (double amplitude) { constants.amplitude.setValue(amplitude); } //-----------------// // setLearningRate // //-----------------// /** * Dynamically modify the learning rate of the neural network for * its training task. * * @param learningRate new learning rate to use */ public void setLearningRate (double learningRate) { constants.learningRate.setValue(learningRate); engine.setLearningRate(learningRate); } //---------------// // setListEpochs // //---------------// /** * Modify the upper limit on the number of epochs (training * iterations) for the training process. * * @param listEpochs new value for iteration limit */ public void setListEpochs (int listEpochs) { constants.listEpochs.setValue(listEpochs); engine.setEpochs(listEpochs); } //-------------// // setMaxError // //-------------// /** * Modify the error threshold to potentially stop the training * process. * * @param maxError the new threshold value to use */ public void setMaxError (double maxError) { constants.maxError.setValue(maxError); engine.setMaxError(maxError); } //-------------// // setMomentum // //-------------// /** * Modify the value for momentum used from learning epoch to the * other. * * @param momentum the new momentum value to be used */ public void setMomentum (double momentum) { constants.momentum.setValue(momentum); engine.setMomentum(momentum); } //------// // stop // //------// /** * Forward the "Stop" order to the network being trained. */ @Override public void stop () { engine.stop(); } //-------// // train // //-------// /** * Train the network using the provided collection of glyphs. * * @param glyphs the provided collection of glyphs * @param monitor the monitoring entity if any * @param mode the starting mode of the trainer (scratch or incremental) */ @SuppressWarnings("unchecked") @Override public void train (Collection<Glyph> glyphs, Monitor monitor, StartingMode mode) { if (glyphs.isEmpty()) { logger.warn("No glyph to retrain Neural Network evaluator"); return; } int quorum = constants.quorum.getValue(); // Determine cardinality for each shape EnumMap<Shape, List<Glyph>> shapeGlyphs = new EnumMap<>(Shape.class); for (Glyph glyph : glyphs) { Shape shape = glyph.getShape(); List<Glyph> list = shapeGlyphs.get(shape); if (list == null) { list = new ArrayList<>(); shapeGlyphs.put(shape, list); } list.add(glyph); } List<Glyph> newGlyphs = new ArrayList<>(); for (List<Glyph> list : shapeGlyphs.values()) { int card = 0; boolean first = true; if (!list.isEmpty()) { while (card < quorum) { for (int i = 0; i < list.size(); i++) { newGlyphs.add(list.get(i)); card++; if (!first && (card >= quorum)) { break; } } first = false; } } } // Shuffle the final collection of glyphs Collections.shuffle(newGlyphs); // Build the collection of patterns from the glyph data double[][] inputs = new double[newGlyphs.size()][]; double[][] desiredOutputs = new double[newGlyphs.size()][]; int ig = 0; for (Glyph glyph : newGlyphs) { double[] ins = ShapeDescription.features(glyph); inputs[ig] = ins; double[] des = new double[shapeCount]; Arrays.fill(des, 0); des[glyph.getShape().getPhysicalShape().ordinal()] = 1; desiredOutputs[ig] = des; ig++; } // Starting options if (mode == StartingMode.SCRATCH) { engine = createNetwork(); } // Train on the patterns engine.train(inputs, desiredOutputs, monitor); } //-------------// // getFileName // //-------------// @Override protected String getFileName () { return FILE_NAME; } //-------------------// // getRawEvaluations // //-------------------// @Override protected Evaluation[] getRawEvaluations (Glyph glyph) { // If too small, it's just NOISE if (!isBigEnough(glyph)) { return noiseEvaluations; } else { double[] ins = ShapeDescription.features(glyph); double[] outs = new double[shapeCount]; Evaluation[] evals = new Evaluation[shapeCount]; Shape[] values = Shape.values(); engine.run(ins, null, outs); for (int s = 0; s < shapeCount; s++) { Shape shape = values[s]; // Use a grade in 0 .. 100 range evals[s] = new Evaluation(shape, 100 * outs[s]); } // Order the evals from best to worst Arrays.sort(evals); return evals; } } //---------// // marshal // //---------// @Override protected void marshal (OutputStream os) throws FileNotFoundException, IOException, JAXBException { engine.marshal(os); } //-----------// // unmarshal // //-----------// @Override protected NeuralNetwork unmarshal (InputStream is) throws JAXBException, IOException { return NeuralNetwork.unmarshal(is); } //---------------// // createNetwork // //---------------// private NeuralNetwork createNetwork () { // Note : We allocate a hidden layer with as many cells as the output // layer NeuralNetwork nn = new NeuralNetwork( ShapeDescription.length(), shapeCount, shapeCount, getAmplitude(), ShapeDescription.getParameterLabels(), // Input labels ShapeSet.getPhysicalShapeNames(), // Output labels getLearningRate(), getMomentum(), getMaxError(), getListEpochs()); return nn; } //~ Inner Classes ---------------------------------------------------------- private static final class Constants extends ConstantSet { //~ Instance fields ---------------------------------------------------- Constant.Ratio amplitude = new Constant.Ratio( 0.5, "Initial weight amplitude"); Constant.Ratio learningRate = new Constant.Ratio( 0.2, "Learning Rate"); Constant.Integer listEpochs = new Constant.Integer( "Epochs", 4000, "Number of epochs for training on list of glyphs"); Constant.Integer quorum = new Constant.Integer( "Glyphs", 10, "Minimum number of glyphs for each shape"); Evaluation.Grade maxError = new Evaluation.Grade( 1E-3, "Threshold to stop training"); Constant.Ratio momentum = new Constant.Ratio(0.2, "Training momentum"); } }