/* * Encog(tm) Workbench v3.4 * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-workbench * * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.workbench.tabs.incremental; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Font; import java.awt.FontMetrics; import java.awt.Graphics; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.io.File; import java.util.Date; import javax.swing.JButton; import javax.swing.JPanel; import org.encog.StatusReportable; import org.encog.ml.data.MLDataSet; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.pattern.FeedForwardPattern; import org.encog.neural.prune.PruneIncremental; import org.encog.persist.EncogDirectoryPersistence; import org.encog.util.Format; import org.encog.workbench.EncogWorkBench; import org.encog.workbench.tabs.EncogCommonTab; import org.encog.workbench.util.EncogFonts; public class IncrementalPruneTab extends EncogCommonTab implements ActionListener, Runnable, StatusReportable { /** * The start button. */ private final JButton buttonStart; /** * The stop button. */ private final JButton buttonStop; /** * The close button. */ private final JButton buttonClose; /** * The body of the dialog box is stored in this panel. */ private final JPanel panelBody; /** * The buttons are hold in this panel. */ private final JPanel panelButtons; /** * The background thread that processes training. */ private Thread thread; /** * Has training been canceled. */ private boolean cancel; /** * When was training started. */ private Date started; private String status; private JPanel statusPanel; private JPanel chartPanel; private int total; private int current; private double low; private double high; private final File path; /** * The font to use for headings. */ private Font headFont; /** * The font for body text. */ private Font bodyFont; /** * Should the dialog box exit? Are we waiting for training to shut down * first. */ private boolean shouldExit; private PruneIncremental prune; private int iterations; private int weightTries; private MLDataSet training; private FeedForwardPattern pattern; private int windowSize; public IncrementalPruneTab(int iterations, int weightTries, int windowSize, MLDataSet training, FeedForwardPattern pattern, File path) { super(null); this.weightTries = weightTries; this.iterations = iterations; this.training = training; this.pattern = pattern; this.windowSize = windowSize; this.prune = new PruneIncremental(this.training, this.pattern, this.iterations, this.weightTries , this.windowSize, this); this.prune.init(); this.buttonStart = new JButton("Start"); this.buttonStop = new JButton("Stop"); this.buttonClose = new JButton("Close"); this.buttonStart.addActionListener(this); this.buttonStop.addActionListener(this); this.buttonClose.addActionListener(this); this.path = path; setLayout(new BorderLayout()); this.panelBody = new JPanel(); this.panelButtons = new JPanel(); this.panelButtons.add(this.buttonStart); this.panelButtons.add(this.buttonStop); this.panelButtons.add(this.buttonClose); add(this.panelBody, BorderLayout.CENTER); add(this.panelButtons, BorderLayout.SOUTH); this.panelBody.setLayout(new BorderLayout()); this.panelBody.add(this.statusPanel = new IncrementalPruneStatusPanel( this), BorderLayout.NORTH); this.panelBody.add(this.chartPanel = new IncrementalPruneChart(this), BorderLayout.CENTER); this.buttonStop.setEnabled(false); this.shouldExit = false; this.bodyFont = EncogFonts.getInstance().getBodyFont(); this.headFont = EncogFonts.getInstance().getHeadFont(); this.status = "Ready to Start"; } /** * Track button presses. * * @param e * Event info. */ public void actionPerformed(final ActionEvent e) { if (e.getSource() == this.buttonClose) { dispose(); } else if (e.getSource() == this.buttonStart) { performStart(); } else if (e.getSource() == this.buttonStop) { performStop(); } } /** * Start the training. */ private void performStart() { this.started = new Date(); this.buttonStart.setEnabled(false); this.buttonStop.setEnabled(true); this.cancel = false; this.status = "Started"; repaint(); this.thread = new Thread(this); this.thread.start(); } /** * Request that the training stop. */ private void performStop() { this.buttonStop.setEnabled(false); this.status = "Canceled"; this.cancel = true; this.repaint(); this.prune.stop(); } public void run() { try { this.prune.process(); this.buttonStart.setEnabled(false); this.buttonStop.setEnabled(false); this.thread = null; if (this.shouldExit) { dispose(); } } catch(Throwable t) { EncogWorkBench.displayError("Error", t); } } public void paintStatus(Graphics g) { g.setColor(Color.white); final int width = getWidth(); final int height = getHeight(); g.fillRect(0, 0, width, height); g.setColor(Color.black); g.setFont(this.headFont); final FontMetrics fm = g.getFontMetrics(); int y = fm.getHeight(); g.drawString("Progress:", 10, y); y += fm.getHeight(); g.drawString("Percent Complete:", 10, y); y += fm.getHeight(); g.drawString("Status:", 10, y); y = fm.getHeight(); g.drawString("High Error:", 250, y); y += fm.getHeight(); g.drawString("Low Error:", 250, y); y = fm.getHeight(); g.drawString("Iterations to Try:", 450, y); y += fm.getHeight(); g.drawString("Weights to Try:", 450, y); g.setFont(this.bodyFont); StringBuilder progress = new StringBuilder(); if (this.total > 0) { progress.append(Format.formatInteger(this.current)); progress.append(" of "); progress.append(Format.formatInteger(this.total)); } double percent = 0; if (this.total > 0) percent = (double) this.current / (double) this.total; y = fm.getHeight(); g.drawString(progress.toString(), 150, y); y += fm.getHeight(); g.drawString(Format.formatPercent(percent), 150, y); y += fm.getHeight(); g.drawString(this.status, 150, y); y = fm.getHeight(); g.drawString(Format.formatPercent(this.high), 350, y); y += fm.getHeight(); g.drawString(Format.formatPercent(this.low), 350, y); y = fm.getHeight(); g.drawString(Format.formatInteger(this.iterations), 550, y); y += fm.getHeight(); g.drawString(Format.formatInteger(this.weightTries), 550, y); } public void paintChart(Graphics g, int width, int height) { g.setColor(Color.black); this.high = this.prune.getHigh(); this.low = this.prune.getLow(); if( this.prune.getHidden1Size()==0 && this.prune.getHidden2Size()==0 ) { g.drawString("Chart not supported for more than 2 layers.", 0, 20); } else if( this.prune.getHidden1Size()>0 ) { int blockWidth = (this.prune.getHidden2Size()>0) ? (width-32)/this.prune.getHidden2Size() : (width-32); int blockHeight = (height-32)/this.prune.getHidden1Size(); g.setFont(this.headFont); g.drawString("H1", 10, height/2); g.drawString(""+this.prune.getHidden().get(0).getMin(), 10, 42); if( this.prune.getHidden().size()>1 ) { g.drawString("H2", width/2, 15); g.drawString(""+this.prune.getHidden().get(1).getMin(), 32, 15); } int xLimit = Math.max(this.prune.getHidden2Size(),1); for(int y=0;y<this.prune.getHidden1Size();y++) { for(int x=0;x<xLimit;x++) { int xLoc = x*blockWidth; int yLoc = y*blockHeight; double error = this.prune.getResults()[y][x]; if( error>0.00001 ) { high = Math.max(high,error); low = Math.min(low,error); double range = high - low; double p = (error-low)/range; int c = (int)(p*255.0); g.setColor(new Color(c,c,c)); g.fillRect(32+xLoc, 32+yLoc, blockWidth, blockHeight); } else { g.setColor(Color.black); g.drawRect(32+xLoc, 32+yLoc, blockWidth, blockHeight); } } } } } public void report(int total, int current, String message) { this.total = total; this.current = current; this.status = message; repaint(); } public int getIterations() { return iterations; } public void setIterations(int iterations) { this.iterations = iterations; } public MLDataSet getTraining() { return training; } public void setTraining(MLDataSet training) { this.training = training; } public FeedForwardPattern getPattern() { return pattern; } public void setPattern(FeedForwardPattern pattern) { this.pattern = pattern; } public boolean close() { if (this.thread == null) { performClose(); return true; } else { this.shouldExit = true; this.cancel = true; return false; } } public void performClose() { if (this.prune != null) { BasicNetwork network = this.prune.getBestNetwork(); if (network != null) { if (EncogWorkBench.askQuestion("Network", "Do you wish to save this network?")) { if (network != null) { EncogDirectoryPersistence.saveObject(this.path,network); EncogWorkBench.getInstance().refresh(); } } } } } public void addHiddenRange(int low, int high) { this.prune.addHiddenLayer(low, high); this.prune.init(); this.repaint(); } public void reportPhase(int arg0, int arg1, String arg2) { // TODO Auto-generated method stub } @Override public String getName() { return "Prune Progress"; } }