/* * Encog(tm) Workbench v3.4 * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-workbench * * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.workbench.tabs.proben; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Font; import java.awt.FontMetrics; import java.awt.Graphics; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.util.ArrayList; import java.util.List; import javax.swing.JButton; import javax.swing.JPanel; import org.encog.app.analyst.AnalystError; import org.encog.mathutil.NumericRange; import org.encog.mathutil.error.ErrorCalculation; import org.encog.ml.MLError; import org.encog.ml.MLMethod; import org.encog.ml.factory.MLMethodFactory; import org.encog.ml.factory.MLTrainFactory; import org.encog.ml.train.MLTrain; import org.encog.ml.train.strategy.end.EarlyStoppingStrategy; import org.encog.ml.train.strategy.end.EndIterationsStrategy; import org.encog.util.Format; import org.encog.util.Stopwatch; import org.encog.workbench.EncogWorkBench; import org.encog.workbench.tabs.EncogCommonTab; import org.encog.workbench.util.EncogFonts; public class ProbenStatusTab extends EncogCommonTab implements Runnable, ActionListener { /** * The start button. */ private final JButton buttonStart; /** * The stop button. */ private final JButton buttonStopAll; /** * Stop the current command. */ private final JButton buttonStopCurrent; /** * The close button. */ private final JButton buttonClose; /** * The body of the dialog box is stored in this panel. */ private final JPanel panelBody; /** * The buttons are hold in this panel. */ private final JPanel panelButtons; /** * The background thread that processes training. */ private Thread thread; private boolean cancelCommand; private boolean cancelAll; /** * The font to use for headings. */ private Font headFont; /** * The font for body text. */ private Font bodyFont; private String status; private int currentDataset = 0; private String trainingError = ""; private String validationError = ""; private String testError = ""; private String trainingIterations = ""; private String currentTrainingRun = ""; private boolean shouldExit; private long lastUpdate; private Stopwatch totalTime = new Stopwatch(); private Stopwatch commandTime = new Stopwatch(); private MLTrain train; private ProBenFiles files = new ProBenFiles(); private String methodName; private String methodArchitecture; private String trainingName; private String trainingArgs; private ProBenData data; private int trainingRuns; private int maxIterations; private List<Double> listTrainingError = new ArrayList<Double>(); private List<Double> listValidationError = new ArrayList<Double>(); private List<Double> listTestError = new ArrayList<Double>(); private List<Double> listIterations = new ArrayList<Double>(); /** * Construct the dialog box. * * @param owner * The owner of the dialog box. */ public ProbenStatusTab( int theTrainingRuns, int theMaxIterations, String theMethodName, String theMethodArchitecture, String theTrainingName, String theTrainingArgs) { super(null); this.trainingRuns = theTrainingRuns; this.maxIterations = theMaxIterations; this.methodName = theMethodName; this.methodArchitecture = theMethodArchitecture; this.trainingName = theTrainingName; this.trainingArgs = theTrainingArgs; this.status = "Waiting to start."; this.buttonStart = new JButton("Start"); this.buttonStopAll = new JButton("Stop All Datasets"); this.buttonStopCurrent = new JButton("Stop Current Dataset"); this.buttonClose = new JButton("Close"); this.buttonStart.addActionListener(this); this.buttonStopAll.addActionListener(this); this.buttonClose.addActionListener(this); this.buttonStopCurrent.addActionListener(this); setLayout(new BorderLayout()); this.panelBody = new ProbenStatusPanel(this); this.panelButtons = new JPanel(); this.panelButtons.add(this.buttonStart); this.panelButtons.add(this.buttonStopAll); this.panelButtons.add(this.buttonStopCurrent); this.panelButtons.add(this.buttonClose); add(this.panelBody, BorderLayout.CENTER); add(this.panelButtons, BorderLayout.SOUTH); this.buttonStopAll.setEnabled(false); this.buttonStopCurrent.setEnabled(false); this.bodyFont = EncogFonts.getInstance().getBodyFont(); this.headFont = EncogFonts.getInstance().getHeadFont(); } private void performClose() { } /** * Track button presses. * * @param e * Event info. */ public void actionPerformed(final ActionEvent e) { if (e.getSource() == this.buttonClose) { dispose(); } else if (e.getSource() == this.buttonStart) { performStart(); } else if (e.getSource() == this.buttonStopAll) { performStopAll(); } else if (e.getSource() == this.buttonStopCurrent) { performStopCurrent(); } } public boolean close() { if (this.thread == null) { performClose(); return true; } else { this.shouldExit = true; this.cancelAll = true; return false; } } public void paintStatus(final Graphics g) { g.setColor(Color.white); final int width = getWidth(); final int height = getHeight(); g.fillRect(0, 0, width, height); g.setColor(Color.black); g.setFont(this.headFont); final FontMetrics fm = g.getFontMetrics(); int y = fm.getHeight(); g.drawString("Overall Status:", 10, y); y += fm.getHeight(); g.drawString("Total Datasets:", 10, y); y += fm.getHeight(); g.drawString("Current Dataset Name:", 10, y); y += fm.getHeight(); g.drawString("Current Dataset Number:", 10, y); y += fm.getHeight(); g.drawString("Max Iterations:", 10, y); y += fm.getHeight(); g.drawString("Training run:", 10, y); y = fm.getHeight(); g.drawString("Elapsed Time:", 350, y); y += fm.getHeight(); g.drawString("Command Elapsed Time:", 350, y); y += fm.getHeight(); g.drawString("Training Type:", 350, y); y += fm.getHeight(); g.drawString("Error Calc Type:", 350, y); y += fm.getHeight(); g.drawString("Training Iterations:", 350, y); y += fm.getHeight(); g.drawString("Training Error:", 350, y); y += fm.getHeight(); g.drawString("Validation Error:", 350, y); y += fm.getHeight(); g.drawString("Test Error:", 350, y); y = fm.getHeight(); g.setFont(this.bodyFont); g.drawString(this.status, 175, y); y += fm.getHeight(); g.drawString("" + this.files.getList().size(), 175, y); y += fm.getHeight(); g.drawString( this.currentDataset==0?"N/A":this.files.getList().get(this.currentDataset-1), 175, y); y += fm.getHeight(); g.drawString(this.currentDataset + " / " + this.files.getList().size(), 175, y); y += fm.getHeight(); g.drawString(Format.formatInteger(this.maxIterations), 175, y); y += fm.getHeight(); g.drawString(this.currentTrainingRun, 175, y); y += fm.getHeight(); String time1 = Format.formatTimeSpan((int)(this.totalTime.getElapsedMilliseconds()/1000)); String time2 = Format.formatTimeSpan((int)(this.commandTime.getElapsedMilliseconds()/1000)); y = fm.getHeight(); g.setFont(this.bodyFont); g.drawString(time1, 500, y); y += fm.getHeight(); g.drawString(time2, 500, y); y += fm.getHeight(); if( train!=null ) { g.drawString(train.getClass().getSimpleName(), 500, y); } y += fm.getHeight(); g.drawString(ErrorCalculation.getMode().toString(), 500, y); y += fm.getHeight(); g.drawString(this.trainingIterations, 500, y); y += fm.getHeight(); g.drawString(this.trainingError, 500, y); y += fm.getHeight(); g.drawString(this.validationError, 500, y); y += fm.getHeight(); g.drawString(this.testError, 500, y); } /** * Start the training. */ private void performStart() { this.buttonStart.setEnabled(false); this.buttonStopAll.setEnabled(true); this.buttonStopCurrent.setEnabled(true); this.cancelAll = false; this.cancelCommand = false; this.status = "Started"; this.thread = new Thread(this); this.thread.start(); } /** * Request that the training stop. */ private void performStopAll() { this.status = "Canceled"; this.cancelCommand = true; this.cancelAll = true; } private void performStopCurrent() { this.cancelCommand = true; } private void evaluate() { this.cancelCommand = false; this.commandTime.reset(); MLMethodFactory methodFactory = new MLMethodFactory(); MLMethod method = methodFactory.create(methodName, methodArchitecture, data.getInputCount(), data.getIdealCount()); MLTrainFactory trainFactory = new MLTrainFactory(); this.train = trainFactory.create(method, data.getTrainingDataSet(), trainingName, trainingArgs); train.addStrategy(new EndIterationsStrategy(this.maxIterations)); train.addStrategy(new EarlyStoppingStrategy(data.getValidationDataSet())); Stopwatch sw = new Stopwatch(); sw.start(); MLError calc = (MLError) train.getMethod(); int iterations = 0; do { train.iteration(); iterations++; if (sw.getElapsedMilliseconds() > 1000) { this.trainingError = Format.formatPercent(train.getError()); this.testError = Format.formatPercent(calc.calculateError(data .getTestDataSet())); this.validationError = Format.formatPercent(calc .calculateError(data.getValidationDataSet())); this.trainingIterations = Format.formatInteger(iterations); update(); sw.reset(); } } while (train.getError() > 0.01 && !this.shouldExit && !this.cancelCommand && !this.cancelAll && !train.isTrainingDone()); double trainError = calc.calculateError(data.getTrainingDataSet()); double testError = calc.calculateError(data.getTestDataSet()); double validationError = calc.calculateError(data .getValidationDataSet()); this.listTrainingError.add(trainError); this.listValidationError.add(validationError); this.listTestError.add(testError); this.listIterations.add((double)iterations); } /** * Process the background thread. Cycle through training iterations. If the * cancel flag is set, then exit. */ public void run() { try { this.status = "Running..."; this.totalTime.reset(); this.commandTime.reset(); this.totalTime.start(); update(); for(int i=0;i<this.files.getList().size()&&!this.shouldExit&&!this.cancelAll;i++) { this.listIterations.clear(); this.listTestError.clear(); this.listTrainingError.clear(); this.listValidationError.clear(); this.currentDataset = i+1; this.data = new ProBenData(this.files.getList().get(i)); this.data.load(); for(int r=0;r<this.trainingRuns;r++) { this.currentTrainingRun = (r+1) + "/" + this.trainingRuns; evaluate(); } writeResult(); update(); } EncogWorkBench.getInstance().getMainWindow().getTree().refresh(); } catch (AnalystError ex) { ex.printStackTrace(); EncogWorkBench.getInstance().outputLine("***Encog Analyst Error"); EncogWorkBench.getInstance().outputLine(ex.getMessage()); this.status = "Error encountered."; EncogWorkBench.getInstance().getMainWindow().getTree().refresh(); } catch (Throwable t) { EncogWorkBench.displayError("Error", t); EncogWorkBench.getInstance().outputLine("***Encog Analyst Exception"); EncogWorkBench.getInstance().outputLine(t.getMessage()); this.status = "Exception encountered."; EncogWorkBench.getInstance().getMainWindow().getTree().refresh(); dispose(); } finally { shutdown(); stopped(); this.status = "Done."; update(true); EncogWorkBench.getInstance().refresh(); if (this.shouldExit) { dispose(); } } } private String formatRange(NumericRange r) { StringBuilder result = new StringBuilder(); result.append(Format.formatDouble(r.getMean(), 2)); result.append(" (sdev="); result.append(Format.formatDouble(r.getStandardDeviation(),2)); result.append(")"); return result.toString(); } private String formatPercentRange(NumericRange r) { StringBuilder result = new StringBuilder(); result.append(Format.formatPercent(r.getMean())); result.append(" (sdev="); result.append(Format.formatPercent(r.getStandardDeviation())); result.append(")"); return result.toString(); } private void writeResult() { NumericRange rangeIterations = new NumericRange(this.listIterations); NumericRange rangeTest = new NumericRange(this.listTestError); NumericRange rangeValidation = new NumericRange(this.listValidationError); NumericRange rangeTraining = new NumericRange(this.listTrainingError); String str = data.getName() + "; Iterations=" + formatRange(rangeIterations) + "; Data Size=" + Format.formatInteger((int) data.getTrainingDataSet() .getRecordCount()) + "; Training Error=" + formatPercentRange(rangeTraining) + "; Validation Error=" + formatPercentRange(rangeValidation); EncogWorkBench.getInstance().outputLine(str); } /** * Implemented by subclasses to perform any shutdown after training. */ public void shutdown() { } /** * Implemented by subclasses to perform any activity before training. */ public void startup() { } /** * Called when training has stopped. */ private void stopped() { this.thread = null; this.buttonStart.setEnabled(true); this.buttonStopAll.setEnabled(false); this.buttonStopCurrent.setEnabled(false); this.cancelAll = true; } public void requestShutdown() { this.cancelAll = true; } public boolean shouldShutDown() { return this.cancelAll; } public void update() { update(false); } public void update(boolean force) { long now = System.currentTimeMillis(); if( (now-this.lastUpdate)>1000 || force ) { this.lastUpdate = now; repaint(); } } public void requestCancelCommand() { this.cancelCommand = true; } public boolean shouldStopCommand() { return this.cancelCommand; } @Override public String getName() { return "Proben1 Progress"; } /** * @return the files */ public ProBenFiles getFiles() { return files; } }