//----------------------------------------------------------------------------//
// //
// T r a i n i n g P a n e l //
// //
//----------------------------------------------------------------------------//
// <editor-fold defaultstate="collapsed" desc="hdr"> //
// Copyright © Hervé Bitteur and others 2000-2013. All rights reserved. //
// This software is released under the GNU General Public License. //
// Goto http://kenai.com/projects/audiveris to report bugs or suggestions. //
//----------------------------------------------------------------------------//
// </editor-fold>
package omr.glyph.ui.panel;
import omr.glyph.*;
import omr.glyph.GlyphNetwork;
import omr.glyph.GlyphRepository;
import omr.glyph.Shape;
import static omr.glyph.Shape.*;
import omr.glyph.facets.Glyph;
import static omr.glyph.ui.panel.GlyphTrainer.Task.Activity.*;
import omr.ui.util.Panel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.jgoodies.forms.builder.PanelBuilder;
import com.jgoodies.forms.layout.CellConstraints;
import com.jgoodies.forms.layout.FormLayout;
import java.awt.event.ActionEvent;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Observable;
import java.util.Observer;
import javax.swing.AbstractAction;
import javax.swing.ButtonGroup;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JProgressBar;
import javax.swing.JRadioButton;
import javax.swing.SwingWorker;
/**
* Class {@code TrainingPanel} is a panel dedicated to the training of
* an evaluation engine.
* It is used through its subclasses {@link NetworkPanel} and {@link
* RegressionPanel} to train the neural network engine and the linear
* engine respectively. It is a dedicated companion of class {@link
* GlyphTrainer}.
*
* @author Hervé Bitteur
*/
class TrainingPanel
        implements EvaluationEngine.Monitor, Observer
{
    //~ Static fields/initializers ---------------------------------------------

    /** Usual logger utility */
    private static final Logger logger = LoggerFactory.getLogger(TrainingPanel.class);

    //~ Instance fields --------------------------------------------------------

    /** The swing component */
    protected final Panel component;

    /**
     * Current activity (selecting the population, or training the engine on
     * the selected population).
     */
    protected final GlyphTrainer.Task task;

    /**
     * User action to launch the training.
     * Not assigned here: subclasses are expected to set it.
     */
    protected TrainAction trainAction;

    /** The underlying engine to be trained */
    protected EvaluationEngine engine;

    /** User progress bar to visualize the training process */
    protected JProgressBar progressBar = new JProgressBar();

    /** Common JGoodies constraints for this class and its subclass if any */
    protected CellConstraints cst = new CellConstraints();

    /** Common JGoodies builder for this class and its subclass if any */
    protected PanelBuilder builder;

    /** Repository of known glyphs */
    private final GlyphRepository repository = GlyphRepository.getInstance();

    /**
     * Flag to indicate that the whole population of recorded glyphs (and not
     * just the core ones) is to be considered.
     */
    private boolean useWhole = true;

    /** Display of cardinality of whole population */
    private final JLabel wholeNumber = new JLabel();

    /** Display of cardinality of core population */
    private final JLabel coreNumber = new JLabel();

    /** UI panel dealing with repository selection */
    private final SelectionPanel selectionPanel;

    /**
     * The Neural Network engine.
     * Only used to size the progress bar (its number of epochs).
     */
    private final GlyphNetwork network = GlyphNetwork.getInstance();

    //~ Constructors -----------------------------------------------------------

    //---------------//
    // TrainingPanel //
    //---------------//
    /**
     * Creates a new TrainingPanel object.
     *
     * @param task           the current training task
     * @param standardWidth  standard width for fields & buttons
     * @param engine         the underlying engine to train
     * @param selectionPanel user panel for glyphs selection
     * @param totalRows      total number of display rows, interlines not
     *                       counted
     */
    public TrainingPanel (GlyphTrainer.Task task,
                          String standardWidth,
                          EvaluationEngine engine,
                          SelectionPanel selectionPanel,
                          int totalRows)
    {
        this.engine = engine;
        this.task = task;
        this.selectionPanel = selectionPanel;

        component = new Panel();
        component.setNoInsets();

        FormLayout layout = Panel.makeFormLayout(
                totalRows,
                4,
                "",
                standardWidth,
                standardWidth);

        builder = new PanelBuilder(layout, component);
        builder.setDefaultDialogBorder(); // Useful ?

        // NOTE(review): defineLayout() is protected and thus overridable; an
        // override would run here, before the subclass fields are initialized.
        defineLayout();
    }

    //~ Methods ----------------------------------------------------------------

    /**
     * Notification of the end of a training epoch.
     * Empty here: subclasses may override it.
     *
     * @param epochIndex index of the epoch just ended
     * @param mse        mean square error reported by the engine
     */
    @Override
    public void epochEnded (int epochIndex,
                            double mse)
    {
    }

    //--------------//
    // getComponent //
    //--------------//
    /**
     * Give access to the encapsulated swing component.
     *
     * @return the user panel
     */
    public JComponent getComponent ()
    {
        return component;
    }

    /**
     * Notification that a glyph has been processed.
     * Empty here: subclasses may override it.
     *
     * @param glyph the processed glyph
     */
    @Override
    public void glyphProcessed (final Glyph glyph)
    {
    }

    /**
     * Notification that training has started.
     * Empty here: subclasses may override it.
     *
     * @param epochIndex starting epoch index
     * @param mse        mean square error reported by the engine
     */
    @Override
    public void trainingStarted (final int epochIndex,
                                 final double mse)
    {
    }

    //--------//
    // update //
    //--------//
    /**
     * Method triggered by new task activity: the train action is enabled only
     * when no activity is going on.
     *
     * @param obs    the task object
     * @param unused not used
     */
    @Override
    public void update (Observable obs,
                        Object unused)
    {
        // trainAction is assigned by subclasses; ignore notifications that
        // arrive before it exists rather than failing with an NPE
        if (trainAction == null) {
            return;
        }

        switch (task.getActivity()) {
        case INACTIVE:
            trainAction.setEnabled(true);

            break;

        case SELECTING:
        case TRAINING:
            trainAction.setEnabled(false);

            break;
        }
    }

    //----------//
    // useWhole //
    //----------//
    /**
     * Tell whether the whole glyph base is to be used, or just the core base.
     *
     * @return true if whole, false if core
     */
    public boolean useWhole ()
    {
        return useWhole;
    }

    //--------------//
    // defineLayout //
    //--------------//
    /**
     * Define the common part of the layout: title separator, progress bar,
     * and the whole/core radio buttons with their population counts.
     * Each subclass can augment this layout from its constructor.
     */
    protected void defineLayout ()
    {
        // Buttons to select just the core glyphs, or the whole population
        CoreAction coreAction = new CoreAction();
        JRadioButton coreButton = new JRadioButton(coreAction);
        WholeAction wholeAction = new WholeAction();
        JRadioButton wholeButton = new JRadioButton(wholeAction);

        // Group the radio buttons.
        ButtonGroup group = new ButtonGroup();
        group.add(wholeButton);
        wholeButton.setToolTipText("Use the whole glyph base for any action");
        group.add(coreButton);
        coreButton.setToolTipText(
                "Use only the core glyph base for any action");
        wholeButton.setSelected(true);

        // Evaluator Title & Progress Bar
        int r = 1; // ----------------------------
        String title = engine.getName() + " Training";
        builder.addSeparator(title, cst.xyw(1, r, 7));
        builder.add(progressBar, cst.xyw(9, r, 7));

        r += 2; // ----------------------------
        builder.add(wholeButton, cst.xy(3, r));
        builder.add(wholeNumber, cst.xy(5, r));

        r += 2; // ----------------------------
        builder.add(coreButton, cst.xy(3, r));
        builder.add(coreNumber, cst.xy(5, r));

        // Initialize with population cardinalities.
        // Whole is triggered last so that useWhole ends up true, consistent
        // with the initially selected wholeButton.
        coreAction.actionPerformed(null);
        wholeAction.actionPerformed(null);
    }

    //-----------------//
    // checkPopulation //
    //-----------------//
    /**
     * Check the training population in place.
     * Glyphs whose physical shape is not trainable, or whose shape cannot be
     * handled, are removed from the list. Trainable physical shapes that have
     * no sample left are reported.
     *
     * @param glyphs the training population, modified in place; every glyph
     *               is expected to have a non-null shape
     */
    private void checkPopulation (List<Glyph> glyphs)
    {
        // Booleans default to false: no explicit fill needed
        boolean[] present = new boolean[LAST_PHYSICAL_SHAPE.ordinal() + 1];

        for (Iterator<Glyph> it = glyphs.iterator(); it.hasNext();) {
            Glyph glyph = it.next();
            Shape shape = glyph.getShape();

            try {
                Shape physicalShape = shape.getPhysicalShape();

                if (physicalShape.isTrainable()) {
                    present[physicalShape.ordinal()] = true;
                } else {
                    logger.warn("Removing non trainable shape:{}", physicalShape);
                    it.remove();
                }
            } catch (Exception ex) {
                // Includes a physical shape beyond LAST_PHYSICAL_SHAPE
                // (array index out of bounds)
                logger.warn("Removing weird shape: {}", shape, ex);
                it.remove();
            }
        }

        // Report only the trainable shapes that are missing. Non-trainable ones
        // are removed on purpose above, so reporting them would be noise.
        Shape[] shapes = Shape.values();

        for (int i = 0; i < present.length; i++) {
            if (!present[i] && shapes[i].isTrainable()) {
                logger.warn("Missing shape: {}", shapes[i]);
            }
        }
    }

    //~ Inner Classes ----------------------------------------------------------

    //------------//
    // DumpAction //
    //------------//
    /**
     * User action to dump the underlying engine.
     */
    protected class DumpAction
            extends AbstractAction
    {
        //~ Constructors -------------------------------------------------------

        public DumpAction ()
        {
            super("Dump");
        }

        //~ Methods ------------------------------------------------------------

        @Override
        public void actionPerformed (ActionEvent e)
        {
            engine.dump();
        }
    }

    //-------------//
    // TrainAction //
    //-------------//
    /**
     * User action to train the underlying engine on the selected population.
     * The training itself runs on a separate low-priority thread.
     */
    protected class TrainAction
            extends AbstractAction
    {
        //~ Instance fields ----------------------------------------------------

        /** Specific training starting mode */
        protected EvaluationEngine.StartingMode mode = EvaluationEngine.StartingMode.SCRATCH;

        /** Whether the user must confirm before training starts */
        protected boolean confirmationRequired = true;

        //~ Constructors -------------------------------------------------------

        public TrainAction (String title)
        {
            super(title);
        }

        //~ Methods ------------------------------------------------------------

        @Override
        public void actionPerformed (ActionEvent e)
        {
            // Ask user confirmation
            if (confirmationRequired) {
                int answer = JOptionPane.showConfirmDialog(
                        component,
                        "Do you really want to retrain " + engine.getName()
                        + " from scratch?");

                if (answer != JOptionPane.YES_OPTION) {
                    return;
                }
            }

            class Worker
                    extends Thread
            {
                @Override
                public void run ()
                {
                    train();
                }
            }

            Worker worker = new Worker();
            worker.setPriority(Thread.MIN_PRIORITY);
            worker.start();
        }

        //-------//
        // train //
        //-------//
        /**
         * Load the selected population, check it, then train the engine on it.
         * The task activity is TRAINING for the whole duration, and is always
         * reset to INACTIVE afterwards, even on failure. Any exception still
         * propagates to the caller.
         */
        public void train ()
        {
            task.setActivity(TRAINING);

            try {
                Collection<String> gNames = selectionPanel.getBase(useWhole);

                progressBar.setValue(0);
                progressBar.setMaximum(network.getListEpochs());

                List<Glyph> glyphs = new ArrayList<>();

                for (String gName : gNames) {
                    Glyph glyph = repository.getGlyph(gName, selectionPanel);

                    if (glyph != null) {
                        if (glyph.getShape() != null) {
                            glyphs.add(glyph);
                        } else {
                            logger.warn("Cannot infer shape from {}", gName);
                        }
                    } else {
                        logger.warn("Cannot get glyph {}", gName);
                    }
                }

                // Check that all trainable shapes (and only those ones) are
                // present in the training population
                checkPopulation(glyphs);

                engine.train(glyphs, TrainingPanel.this, mode);
            } finally {
                // Without this, a failure would leave the task in TRAINING,
                // and update() would keep the train action disabled forever
                task.setActivity(INACTIVE);
            }
        }
    }

    //------------//
    // CoreAction //
    //------------//
    /**
     * Selects the core population and displays its cardinality.
     * The count comes from a SwingWorker created once per action instance.
     * A SwingWorker runs only once, so later invocations just toggle the
     * selection flag.
     */
    private class CoreAction
            extends AbstractAction
    {
        //~ Instance fields ----------------------------------------------------

        final SwingWorker<Integer, Object> worker = new SwingWorker<Integer, Object>() {
            @Override
            public void done ()
            {
                try {
                    coreNumber.setText("" + get());
                } catch (Exception ex) {
                    logger.warn("Error while loading core base", ex);
                }
            }

            @Override
            protected Integer doInBackground ()
            {
                return selectionPanel.getBase(false)
                        .size();
            }
        };

        //~ Constructors -------------------------------------------------------

        public CoreAction ()
        {
            super("Core");
        }

        //~ Methods ------------------------------------------------------------

        @Override
        public void actionPerformed (ActionEvent e)
        {
            useWhole = false;
            worker.execute();
        }
    }

    //-------------//
    // WholeAction //
    //-------------//
    /**
     * Selects the whole population and displays its cardinality.
     * The count comes from a SwingWorker created once per action instance.
     * A SwingWorker runs only once, so later invocations just toggle the
     * selection flag.
     */
    private class WholeAction
            extends AbstractAction
    {
        //~ Instance fields ----------------------------------------------------

        final SwingWorker<Integer, Object> worker = new SwingWorker<Integer, Object>() {
            @Override
            public void done ()
            {
                try {
                    wholeNumber.setText("" + get());
                } catch (Exception ex) {
                    logger.warn("Error while loading whole base", ex);
                }
            }

            @Override
            protected Integer doInBackground ()
            {
                return selectionPanel.getBase(true)
                        .size();
            }
        };

        //~ Constructors -------------------------------------------------------

        public WholeAction ()
        {
            super("Whole");
        }

        //~ Methods ------------------------------------------------------------

        @Override
        public void actionPerformed (ActionEvent e)
        {
            useWhole = true;
            worker.execute();
        }
    }
}