//----------------------------------------------------------------------------//
// //
// N e t w o r k P a n e l //
// //
//----------------------------------------------------------------------------//
// <editor-fold defaultstate="collapsed" desc="hdr"> //
// Copyright © Hervé Bitteur and others 2000-2013. All rights reserved. //
// This software is released under the GNU General Public License. //
// Goto http://kenai.com/projects/audiveris to report bugs or suggestions. //
//----------------------------------------------------------------------------//
// </editor-fold>
package omr.glyph.ui.panel;
import omr.glyph.EvaluationEngine;
import omr.glyph.GlyphNetwork;
import omr.glyph.ui.panel.TrainingPanel.DumpAction;
import omr.math.NeuralNetwork;
import omr.ui.field.LDoubleField;
import omr.ui.field.LIntegerField;
import omr.ui.field.LTextField;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.awt.event.ActionEvent;
import java.text.DateFormat;
import java.util.Date;
import java.util.Observable;
import javax.swing.AbstractAction;
import javax.swing.JButton;
import javax.swing.JComponent;
import javax.swing.JOptionPane;
import javax.swing.KeyStroke;
import javax.swing.SwingUtilities;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;
/**
* Class {@code NetworkPanel} is the user interface that handles the
* training of the neural network engine. It is a dedicated companion of
* class {@link GlyphTrainer}.
*
* @author Hervé Bitteur
*/
class NetworkPanel
extends TrainingPanel
{
//~ Static fields/initializers ---------------------------------------------
/** Usual logger utility */
private static final Logger logger = LoggerFactory.getLogger(
        NetworkPanel.class);

//~ Instance fields --------------------------------------------------------

/**
 * Best neural weights so far.
 * Volatile: assigned while running on the trainer thread (see
 * {@code epochEnded} and {@code trainingStarted}) and read by the
 * "Use Best" user action.
 */
private volatile NeuralNetwork.Backup bestSnap;

/**
 * Last neural weights.
 * Volatile: assigned at the end of a training run and read by the
 * "Use Last" user action.
 */
private volatile NeuralNetwork.Backup lastSnap;

/** To display ETA as a date (medium style for both date and time) */
private final DateFormat dateFormat = DateFormat.getDateTimeInstance(
        DateFormat.MEDIUM, // Date
        DateFormat.MEDIUM); // Time

/** Input field for Learning rate of the neural network */
private final LDoubleField learningRate = new LDoubleField(
        "Learning Rate",
        "Learning rate of the neural network",
        "%.2f");

/** Input field for Momentum value of the neural network */
private final LDoubleField momentum = new LDoubleField(
        "Momentum",
        "Momentum value for the neural network",
        "%.2f");

/** Output of Estimated time for end of training */
private final LTextField eta = new LTextField(
        "ETA",
        "Estimated time for end of training");

/** Input field for Maximum number of iterations to perform */
private final LIntegerField listEpochs = new LIntegerField(
        "Epochs",
        "Maximum number of iterations to perform");

/** Output for Index of best configuration so far */
private final LIntegerField bestIndex = new LIntegerField(
        false,
        "Best Index",
        "Index of best configuration so far");

/** Output for Number of iterations performed so far */
private final LIntegerField trainIndex = new LIntegerField(
        false,
        "Last Index",
        "Number of iterations performed so far");

/** Input field for Error threshold to stop learning */
private final LDoubleField maxError = new LDoubleField(
        "Max Error",
        "Error threshold to stop learning");

/** Output for Best recorded value of remaining error */
private final LDoubleField bestError = new LDoubleField(
        false,
        "Best Error",
        "Best recorded value of remaining error");

/** Output for Last value of remaining error */
private final LDoubleField trainError = new LDoubleField(
        false,
        "Last Error",
        "Last value of remaining error");

/** User action to pick the last weights */
private final LastAction lastAction = new LastAction();

/** User action to launch an incremental training (set once by constructor) */
private final NetworkTrainAction incrementalTrainAction;

/** User action to pick the best recorded weights */
private final BestAction bestAction = new BestAction();

/** User action to gracefully stop the training */
private final StopAction stopAction = new StopAction();

/**
 * Remaining error corresponding to best weights.
 * Volatile: a plain double shared between threads offers neither
 * visibility nor atomicity guarantees.
 */
private volatile double bestMse;

/** Potential listener on best error, perhaps null */
private final ChangeListener errorListener;

/** Event related to best error, null when there is no listener */
private final ChangeEvent errorEvent;

/**
 * Remaining error corresponding to last run.
 * Volatile for the same reason as {@link #bestMse}.
 */
private volatile double lastMse;

/** Training start time, in milliseconds since the epoch */
private long startTime;
//~ Constructors -----------------------------------------------------------
//--------------//
// NetworkPanel //
//--------------//
/**
* Creates a new NetworkPanel object.
*
*
* @param task the current training activity
* @param standardWidth standard width for fields & buttons
* @param errorListener a listener on remaining error
* @param selectionPanel the panel for glyph repository
*/
public NetworkPanel (GlyphTrainer.Task task,
String standardWidth,
ChangeListener errorListener,
SelectionPanel selectionPanel)
{
super(
task,
standardWidth,
GlyphNetwork.getInstance(),
selectionPanel,
6);
this.errorListener = errorListener;
task.addObserver(this);
if (errorListener != null) {
errorEvent = new ChangeEvent(this);
} else {
errorEvent = null;
}
eta.getField()
.setEditable(false); // ETA is just an output
component.getInputMap(JComponent.WHEN_ANCESTOR_OF_FOCUSED_COMPONENT)
.put(KeyStroke.getKeyStroke("ENTER"), "readParams");
component.getActionMap()
.put("readParams", new ParamAction());
trainAction = new NetworkTrainAction(
"Re-Train",
EvaluationEngine.StartingMode.SCRATCH,
/* confirmationRequired => */ true);
incrementalTrainAction = new NetworkTrainAction(
"Inc-Train",
EvaluationEngine.StartingMode.INCREMENTAL,
/* confirmationRequired => */ false);
defineSpecificLayout();
displayParams();
}
//~ Methods ----------------------------------------------------------------
//------------//
// epochEnded //
//------------//
/**
 * Called at the end of each training epoch.
 * The last error is recorded and, when it improves on the best error so
 * far, the network weights are snapshot in memory and marshalled to disk.
 * The display (indices, errors, progress bar, ETA) is then refreshed on
 * the swing thread.
 *
 * @param epochIndex the (0-based) index of the epoch just ended
 * @param mse        the remaining error measured at the end of this epoch
 */
@Override
public void epochEnded (final int epochIndex,
                        final double mse)
{
    // Everything up to invokeLater() is run on trainer thread
    final int index = epochIndex + 1;
    lastMse = mse;

    final boolean snapTaken = mse < bestMse;

    if (snapTaken) {
        bestMse = mse;

        // Take a snap of the current weights
        final GlyphNetwork glyphNetwork = (GlyphNetwork) engine;
        bestSnap = glyphNetwork.getNetwork().backup();

        // Belt & suspenders: make a copy on disk!
        glyphNetwork.marshal();
    }

    final Runnable refresher = new Runnable()
    {
        // This part is run on swing thread
        @Override
        public void run ()
        {
            // Current values
            trainIndex.setValue(index);
            trainError.setValue(mse);

            // Best values, only when a new record has just been set
            if (snapTaken) {
                bestIndex.setValue(index);
                bestError.setValue(mse);

                if (errorListener != null) {
                    errorListener.stateChanged(errorEvent);
                }
            }

            progressBar.setValue(index);

            // ETA: extrapolate the time spent so far to all planned epochs
            final long elapsed = System.currentTimeMillis() - startTime;
            final long expected = (GlyphNetwork.getInstance().getListEpochs() * elapsed) / index;
            eta.setText(dateFormat.format(new Date(startTime + expected)));

            component.repaint();
        }
    };

    SwingUtilities.invokeLater(refresher);
}
//--------------//
// getBestError //
//--------------//
/**
 * Report the best (that is, the lowest) remaining error recorded so far.
 * <p>
 * The value is reset to {@code Double.MAX_VALUE} each time the task
 * switches to the training activity, and is then updated when training
 * starts and whenever an epoch ends with a lower error.
 *
 * @return the best error so far
 */
public double getBestError ()
{
return bestMse;
}
//-----------------//
// trainingStarted //
//-----------------//
/**
 * Called when the training is about to start.
 * The initial weights and the initial error are recorded as the best ones
 * so far, and the display of best values is refreshed on the swing thread,
 * where the starting time is also remembered.
 *
 * @param epochIndex the epoch index at start (the displayed index is this
 *                   value plus one)
 * @param mse        the remaining error measured before any epoch is run
 */
@Override
public void trainingStarted (final int epochIndex,
                             final double mse)
{
    // This part is run on trainer thread
    final int index = epochIndex + 1;
    final NeuralNetwork network = ((GlyphNetwork) engine).getNetwork();
    bestSnap = network.backup();
    bestMse = mse;

    // The initial error is also the "last" error until an epoch completes.
    // Without this, a training stopped before its first epoch end would
    // leave lastMse with a stale value (from a previous run, or 0.0), which
    // would then be compared with bestMse and logged as the network error.
    lastMse = mse;

    SwingUtilities.invokeLater(
            new Runnable()
            {
                // This part is run on swing thread
                @Override
                public void run ()
                {
                    // Update best values
                    bestIndex.setValue(index);
                    bestError.setValue(mse);

                    if (errorListener != null) {
                        errorListener.stateChanged(errorEvent);
                    }

                    // Remember starting time, used later for ETA computation
                    startTime = System.currentTimeMillis();
                }
            });
}
//--------//
// update //
//--------//
/**
 * React to a notification of the task activity.
 * On top of the generic processing done by {@link TrainingPanel#update},
 * this method enables or disables the actions that are specific to the
 * neural network and, when training begins, reads the input parameters
 * and resets the record of best weights.
 *
 * @param obs    the task object
 * @param unused not used
 */
@Override
public void update (Observable obs,
                    Object unused)
{
    super.update(obs, unused);

    switch (task.getActivity()) {
    case TRAINING:
        incrementalTrainAction.setEnabled(false);
        stopAction.setEnabled(true);

        // Take the user parameters into account, and echo the actual values
        inputParams();
        displayParams();

        // Forget any previous record
        bestMse = Double.MAX_VALUE;
        bestSnap = null;

        break;

    case SELECTING:
        incrementalTrainAction.setEnabled(false);
        stopAction.setEnabled(false);

        break;

    case INACTIVE:
        incrementalTrainAction.setEnabled(true);
        stopAction.setEnabled(false);

        break;
    }

    // Whatever the activity, choosing between weights is disabled here
    bestAction.setEnabled(false);
    lastAction.setEnabled(false);
}
//----------------------//
// defineSpecificLayout //
//----------------------//
/**
 * Populate the panel with the items specific to the neural network:
 * ETA field, network parameters, then two rows of buttons each followed
 * by the related index and error fields.
 */
private void defineSpecificLayout ()
{
    // Rows used in the layout, one every other row starting at row 3
    final int etaRow = 3;
    final int rateRow = etaRow + 2;
    final int epochRow = rateRow + 2;
    final int bestRow = epochRow + 2;
    final int lastRow = bestRow + 2;

    // ETA field
    builder.add(eta.getLabel(), cst.xy(9, etaRow));
    builder.add(eta.getField(), cst.xyw(11, etaRow, 5));

    // Neural network parameters: momentum & learning rate
    builder.add(momentum.getLabel(), cst.xy(9, rateRow));
    builder.add(momentum.getField(), cst.xy(11, rateRow));
    builder.add(learningRate.getLabel(), cst.xy(13, rateRow));
    builder.add(learningRate.getField(), cst.xy(15, rateRow));

    // Neural network parameters: epochs & max error
    builder.add(listEpochs.getLabel(), cst.xy(9, epochRow));
    builder.add(listEpochs.getField(), cst.xy(11, epochRow));
    builder.add(maxError.getLabel(), cst.xy(13, epochRow));
    builder.add(maxError.getField(), cst.xy(15, epochRow));

    // Row of "best" entities: dump, train from scratch, use best
    final JButton dump = new JButton(new DumpAction());
    dump.setToolTipText("Dump the evaluator internals");

    final JButton train = new JButton(trainAction);
    train.setToolTipText("Re-Train the evaluator from scratch");

    final JButton useBest = new JButton(bestAction);
    useBest.setToolTipText("Use the weights of best snap");

    builder.add(dump, cst.xy(3, bestRow));
    builder.add(train, cst.xy(5, bestRow));
    builder.add(useBest, cst.xy(7, bestRow));
    builder.add(bestIndex.getLabel(), cst.xy(9, bestRow));
    builder.add(bestIndex.getField(), cst.xy(11, bestRow));
    builder.add(bestError.getLabel(), cst.xy(13, bestRow));
    builder.add(bestError.getField(), cst.xy(15, bestRow));

    // Row of "last" entities: stop, incremental train, use last
    final JButton stop = new JButton(stopAction);
    stop.setToolTipText("Stop the training of the evaluator");

    final JButton incTrain = new JButton(incrementalTrainAction);
    incTrain.setToolTipText("Incrementally train the evaluator");

    final JButton useLast = new JButton(lastAction);
    useLast.setToolTipText("Use the last weights");

    builder.add(stop, cst.xy(3, lastRow));
    builder.add(incTrain, cst.xy(5, lastRow));
    builder.add(useLast, cst.xy(7, lastRow));
    builder.add(trainIndex.getLabel(), cst.xy(9, lastRow));
    builder.add(trainIndex.getField(), cst.xy(11, lastRow));
    builder.add(trainError.getLabel(), cst.xy(13, lastRow));
    builder.add(trainError.getField(), cst.xy(15, lastRow));
}
//---------------//
// displayParams //
//---------------//
/**
 * Copy the current parameters of the glyph network (epochs, learning
 * rate, momentum, max error) into the corresponding panel fields.
 */
private void displayParams ()
{
    final GlyphNetwork glyphNetwork = (GlyphNetwork) engine;

    listEpochs.setValue(glyphNetwork.getListEpochs());
    learningRate.setValue(glyphNetwork.getLearningRate());
    momentum.setValue(glyphNetwork.getMomentum());
    maxError.setValue(glyphNetwork.getMaxError());
}
//-------------//
// inputParams //
//-------------//
/**
 * Push the values of the panel input fields (epochs, learning rate,
 * momentum, max error) to the glyph network, then size the progress bar
 * according to the number of epochs reported back by the network.
 */
private void inputParams ()
{
    final GlyphNetwork glyphNetwork = (GlyphNetwork) engine;

    glyphNetwork.setListEpochs(listEpochs.getValue());
    glyphNetwork.setLearningRate(learningRate.getValue());
    glyphNetwork.setMomentum(momentum.getValue());
    glyphNetwork.setMaxError(maxError.getValue());

    // Ask the network rather than the field, to use the value actually set
    progressBar.setMaximum(glyphNetwork.getListEpochs());
}
//~ Inner Classes ----------------------------------------------------------
//------------//
// BestAction //
//------------//
private class BestAction
extends AbstractAction
{
//~ Constructors -------------------------------------------------------
public BestAction ()
{
super("Use Best");
}
//~ Methods ------------------------------------------------------------
@Override
public void actionPerformed (ActionEvent e)
{
GlyphNetwork glyphNetwork = (GlyphNetwork) engine;
NeuralNetwork network = glyphNetwork.getNetwork();
network.restore(bestSnap);
logger.info("Network remaining error : {}", (float) bestMse);
glyphNetwork.marshal();
// Let the user choose the other possibility
setEnabled(false);
lastAction.setEnabled(true);
}
}
//------------//
// LastAction //
//------------//
private class LastAction
extends AbstractAction
{
//~ Constructors -------------------------------------------------------
public LastAction ()
{
super("Use Last");
}
//~ Methods ------------------------------------------------------------
@Override
public void actionPerformed (ActionEvent e)
{
// Ask user confirmation if needed
if (lastMse > bestMse) {
final int answer = JOptionPane.showConfirmDialog(
component,
"Do you want to switch to this non-optimal network ?");
if (answer != JOptionPane.YES_OPTION) {
return;
}
}
GlyphNetwork glyphNetwork = (GlyphNetwork) engine;
NeuralNetwork network = glyphNetwork.getNetwork();
network.restore(lastSnap);
logger.info("Network remaining error : {}", (float) lastMse);
glyphNetwork.marshal();
// Let the user choose the other possibility
setEnabled(false);
bestAction.setEnabled(true);
}
}
//--------------------//
// NetworkTrainAction //
//--------------------//
/**
 * Training action specific to the neural network: once the generic
 * training is over, the last weights are saved and the network is set to
 * the better of the best recorded weights and the last weights.
 */
private class NetworkTrainAction
        extends TrainingPanel.TrainAction
{
    //~ Constructors -------------------------------------------------------

    /**
     * Create a training action.
     *
     * @param title                the label of the action
     * @param mode                 the starting mode for training
     * @param confirmationRequired value for the inherited confirmation flag
     */
    public NetworkTrainAction (String title,
                               EvaluationEngine.StartingMode mode,
                               boolean confirmationRequired)
    {
        super(title);
        this.mode = mode;
        this.confirmationRequired = confirmationRequired;
    }

    //~ Methods ------------------------------------------------------------

    //-------//
    // train //
    //-------//
    @Override
    public void train ()
    {
        // Generic part (presumably the training run itself)
        super.train();

        // Remember the weights as they stand now
        lastSnap = ((GlyphNetwork) engine).getNetwork().backup();

        // By default, keep the better between best recorded and last
        final boolean lastIsGoodEnough = lastMse <= bestMse;

        if (lastIsGoodEnough) {
            lastAction.actionPerformed(null);
        } else {
            bestAction.actionPerformed(null);
        }
    }
}
//-------------//
// ParamAction //
//-------------//
/**
 * Action bound (under the "readParams" key) to the ENTER key stroke of
 * the panel component, in charge of reading the input fields.
 */
private class ParamAction
extends AbstractAction
{
//~ Methods ------------------------------------------------------------
// Purpose is just to read and remember the data from the various
// input fields. Triggered when user presses Enter in one of these
// fields.
@Override
public void actionPerformed (ActionEvent e)
{
// Push the field values to the network ...
inputParams();
// ... then display the values as actually held by the network
displayParams();
}
}
//------------//
// StopAction //
//------------//
/**
 * User action (labelled "Stop") which asks the engine to stop.
 */
private class StopAction
extends AbstractAction
{
//~ Constructors -------------------------------------------------------
public StopAction ()
{
super("Stop");
}
//~ Methods ------------------------------------------------------------
/**
 * Forward the stop request to the engine.
 *
 * @param e the triggering event, not used
 */
@Override
public void actionPerformed (ActionEvent e)
{
engine.stop();
}
}
}