/*
* Encog(tm) Examples v2.4
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
*
* Copyright 2008-2010 by Heaton Research Inc.
*
* Released under the LGPL.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*
* Encog and Heaton Research are Trademarks of Heaton Research, Inc.
* For information on Heaton Research trademarks, visit:
*
* http://www.heatonresearch.com/copyright.html
*/
package org.encog.examples.neural.gui.ocr;
import java.awt.Font;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.OutputStream;
import java.io.PrintStream;
import java.text.NumberFormat;
import javax.swing.DefaultListModel;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JList;
import javax.swing.JOptionPane;
import javax.swing.JScrollPane;
import javax.swing.ScrollPaneConstants;
import javax.swing.SwingConstants;
import javax.swing.SwingUtilities;
import javax.swing.event.ListSelectionEvent;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.mathutil.rbf.GaussianFunction;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.data.basic.BasicNeuralDataPair;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.competitive.CompetitiveTraining;
import org.encog.neural.networks.training.competitive.neighborhood.NeighborhoodSingleRBF;
import org.encog.util.logging.Logging;
/**
 * OCR: Main form that allows the user to interact with the OCR application.
 * <p>
 * The user draws letters into the {@link Entry} component, downsamples them
 * into a {@link Sample} grid, builds a list of known letters, and trains a
 * two-layer network with {@link CompetitiveTraining} so that each known
 * letter is associated with the output neuron that wins for it.
 * <p>
 * Threading: Swing components are only updated on the event dispatch thread
 * (EDT). Training runs on a background thread ({@link #run()}), which reports
 * progress and completion back to the EDT through {@link SwingUtilities}.
 */
public class OCR extends JFrame implements Runnable {

    /**
     * Dispatches button clicks to the matching {@code *_actionPerformed}
     * handler.
     */
    class SymAction implements java.awt.event.ActionListener {
        public void actionPerformed(final java.awt.event.ActionEvent event) {
            final Object object = event.getSource();
            if (object == OCR.this.downSample) {
                downSample_actionPerformed(event);
            } else if (object == OCR.this.clear) {
                clear_actionPerformed(event);
            } else if (object == OCR.this.add) {
                add_actionPerformed(event);
            } else if (object == OCR.this.del) {
                del_actionPerformed(event);
            } else if (object == OCR.this.load) {
                load_actionPerformed(event);
            } else if (object == OCR.this.save) {
                save_actionPerformed(event);
            } else if (object == OCR.this.train) {
                train_actionPerformed(event);
            } else if (object == OCR.this.recognize) {
                recognize_actionPerformed(event);
            }
        }
    }

    /**
     * Forwards selection changes of the letter list to
     * {@link OCR#letters_valueChanged(ListSelectionEvent)}.
     */
    class SymListSelection implements javax.swing.event.ListSelectionListener {
        public void valueChanged(
                final javax.swing.event.ListSelectionEvent event) {
            final Object object = event.getSource();
            if (object == OCR.this.letters) {
                letters_valueChanged(event);
            }
        }
    }

    /**
     * Runnable posted to the EDT to display the current training iteration
     * count and error.
     */
    public class UpdateStats implements Runnable {
        /** The iteration count to display. */
        long tries;
        /** The training error to display. */
        double error;

        public void run() {
            OCR.this.tries.setText("" + this.tries);
            OCR.this.txtError.setText(""
                    + OCR.this.numberFormat.format(this.error));
        }
    }

    /**
     * Serial id for this class.
     */
    private static final long serialVersionUID = -6779380961875907013L;

    /**
     * The downsample width for the application.
     */
    static final int DOWNSAMPLE_WIDTH = 5;

    /**
     * The down sample height for the application.
     */
    static final int DOWNSAMPLE_HEIGHT = 7;

    /**
     * The main method.
     *
     * @param args
     *            Args not really used.
     */
    public static void main(final String[] args) {
        Logging.stopConsoleLogging();
        // Swing components must be created and shown on the EDT.
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                (new OCR()).setVisible(true);
            }
        });
    }

    /**
     * Formats the training error for display.
     */
    private final NumberFormat numberFormat;

    /**
     * Stop request for the training loop. It is written on the EDT and read
     * by the training thread, so it must be volatile for the loop to observe
     * it.
     */
    private volatile boolean halt;

    /**
     * The entry component for the user to draw into.
     */
    private final Entry entry;

    /**
     * The down sample component to display the drawing downsampled.
     */
    private final Sample sample;

    /**
     * The letters that have been defined.
     */
    private final DefaultListModel letterListModel = new DefaultListModel();

    /**
     * The neural network. It is assigned by the training thread and read on
     * the EDT. It is volatile and is only published once fully constructed.
     */
    private volatile BasicNetwork net;

    /**
     * The background thread used for training; {@code null} when idle. Only
     * accessed on the EDT.
     */
    private Thread trainThread = null;

    private final JLabel JLabel1 = new javax.swing.JLabel();
    private final JLabel JLabel2 = new javax.swing.JLabel();

    /**
     * The downsample button.
     */
    private final JButton downSample = new JButton();

    /**
     * The add button.
     */
    private final JButton add = new JButton();

    /**
     * The clear button.
     */
    private final JButton clear = new JButton();

    /**
     * The recognize button.
     */
    private final JButton recognize = new JButton();

    private final JScrollPane JScrollPane1 = new JScrollPane();

    /**
     * The letters list box.
     */
    private final JList letters = new JList();

    /**
     * The delete button.
     */
    private final JButton del = new JButton();

    /**
     * The load button.
     */
    private final JButton load = new JButton();

    /**
     * The save button.
     */
    private final JButton save = new JButton();

    /**
     * The train button; its text toggles between "Begin Training" and
     * "Stop Training".
     */
    JButton train = new JButton();

    JLabel JLabel3 = new JLabel();
    JLabel JLabel4 = new JLabel();

    /**
     * How many tries (training iterations) have been run.
     */
    JLabel tries = new JLabel();

    /**
     * The last training error.
     */
    JLabel txtError = new JLabel();

    JLabel JLabel8 = new JLabel();
    JLabel JLabel5 = new JLabel();

    /**
     * The constructor. Lays out all components with absolute positioning and
     * wires up the listeners.
     */
    OCR() {
        getContentPane().setLayout(null);
        this.entry = new Entry();
        this.entry.setLocation(168, 25);
        this.entry.setSize(200, 128);
        getContentPane().add(this.entry);

        this.sample = new Sample(OCR.DOWNSAMPLE_WIDTH, OCR.DOWNSAMPLE_HEIGHT);
        this.sample.setLocation(307, 210);
        this.sample.setSize(65, 70);
        this.entry.setSample(this.sample);
        getContentPane().add(this.sample);

        setTitle("Java Neural Network");
        setSize(405, 382);
        setVisible(false);

        this.JLabel1.setText("Letters Known");
        getContentPane().add(this.JLabel1);
        this.JLabel1.setBounds(12, 12, 100, 12);
        this.JLabel2.setText("Tries:");
        getContentPane().add(this.JLabel2);
        this.JLabel2.setBounds(12, 264, 72, 24);
        this.downSample.setText("D Sample");
        this.downSample.setActionCommand("Down Sample");
        getContentPane().add(this.downSample);
        this.downSample.setBounds(252, 180, 120, 24);
        this.add.setText("Add");
        this.add.setActionCommand("Add");
        getContentPane().add(this.add);
        this.add.setBounds(168, 156, 84, 24);
        this.clear.setText("Clear");
        this.clear.setActionCommand("Clear");
        getContentPane().add(this.clear);
        this.clear.setBounds(168, 180, 84, 24);
        this.recognize.setText("Recognize");
        this.recognize.setActionCommand("Recognize");
        getContentPane().add(this.recognize);
        this.recognize.setBounds(252, 156, 120, 24);
        this.JScrollPane1
                .setVerticalScrollBarPolicy(ScrollPaneConstants.VERTICAL_SCROLLBAR_ALWAYS);
        this.JScrollPane1.setOpaque(true);
        getContentPane().add(this.JScrollPane1);
        this.JScrollPane1.setBounds(12, 24, 144, 132);
        this.JScrollPane1.getViewport().add(this.letters);
        this.letters.setBounds(0, 0, 126, 129);
        this.del.setText("Delete");
        this.del.setActionCommand("Delete");
        getContentPane().add(this.del);
        this.del.setBounds(12, 156, 144, 24);
        this.load.setText("Load");
        this.load.setActionCommand("Load");
        getContentPane().add(this.load);
        this.load.setBounds(12, 180, 75, 24);
        this.save.setText("Save");
        this.save.setActionCommand("Save");
        getContentPane().add(this.save);
        this.save.setBounds(84, 180, 72, 24);
        this.train.setText("Begin Training");
        this.train.setActionCommand("Begin Training");
        getContentPane().add(this.train);
        this.train.setBounds(12, 204, 144, 24);
        this.JLabel3.setText("Error:");
        getContentPane().add(this.JLabel3);
        this.JLabel3.setBounds(12, 288, 72, 24);
        this.tries.setText("0");
        getContentPane().add(this.tries);
        this.tries.setBounds(96, 264, 72, 24);
        this.txtError.setText("0");
        getContentPane().add(this.txtError);
        this.txtError.setBounds(96, 288, 72, 24);
        this.JLabel8.setHorizontalTextPosition(SwingConstants.CENTER);
        this.JLabel8.setHorizontalAlignment(SwingConstants.CENTER);
        this.JLabel8.setText("Training Results");
        getContentPane().add(this.JLabel8);
        this.JLabel8.setFont(new Font("Dialog", Font.BOLD, 14));
        this.JLabel8.setBounds(12, 240, 120, 24);
        this.JLabel5.setText("Draw Letters Here");
        getContentPane().add(this.JLabel5);
        this.JLabel5.setBounds(204, 12, 144, 12);

        final SymAction lSymAction = new SymAction();
        this.downSample.addActionListener(lSymAction);
        this.clear.addActionListener(lSymAction);
        this.add.addActionListener(lSymAction);
        this.del.addActionListener(lSymAction);
        final SymListSelection lSymListSelection = new SymListSelection();
        this.letters.addListSelectionListener(lSymListSelection);
        this.load.addActionListener(lSymAction);
        this.save.addActionListener(lSymAction);
        this.train.addActionListener(lSymAction);
        this.recognize.addActionListener(lSymAction);
        this.letters.setModel(this.letterListModel);
        this.setDefaultCloseOperation(EXIT_ON_CLOSE);
        this.numberFormat = NumberFormat.getNumberInstance();
    }

    /**
     * Converts a downsampled grid into network input, row by row: a set
     * pixel becomes 0.5 and a clear pixel becomes -0.5.
     *
     * @param ds
     *            The sample to convert.
     * @return Input data with {@code width * height} elements.
     */
    private static NeuralData toNeuralData(final SampleData ds) {
        final NeuralData result = new BasicNeuralData(ds.getWidth()
                * ds.getHeight());
        int idx = 0;
        for (int y = 0; y < ds.getHeight(); y++) {
            for (int x = 0; x < ds.getWidth(); x++) {
                result.setData(idx++, ds.getData(x, y) ? .5 : -.5);
            }
        }
        return result;
    }

    /**
     * Called to add the current image to the training set. The new sample is
     * inserted before the first letter that sorts after it, then selected.
     *
     * @param event
     *            The event
     */
    @SuppressWarnings("unchecked")
    void add_actionPerformed(final java.awt.event.ActionEvent event) {
        final String letter = JOptionPane
                .showInputDialog("Please enter a letter you would like to assign this sample to.");
        if (letter == null) {
            // The user cancelled the dialog.
            return;
        }
        // Rejects empty input too; charAt(0) below would otherwise throw.
        if (letter.length() != 1) {
            JOptionPane.showMessageDialog(this,
                    "Please enter only a single letter.", "Error",
                    JOptionPane.ERROR_MESSAGE);
            return;
        }
        final char newLetter = letter.charAt(0);

        this.entry.downSample();
        final SampleData sampleData = (SampleData) this.sample.getData()
                .clone();
        sampleData.setLetter(newLetter);

        // Scan the whole list for duplicates. Lists loaded from disk keep file
        // order, so a duplicate may appear after a larger letter.
        int insertAt = -1;
        for (int i = 0; i < this.letterListModel.size(); i++) {
            final SampleData existing = (SampleData) this.letterListModel
                    .getElementAt(i);
            if (existing.getLetter() == newLetter) {
                JOptionPane.showMessageDialog(this,
                        "That letter is already defined, delete it first!",
                        "Error", JOptionPane.ERROR_MESSAGE);
                return;
            }
            if (insertAt < 0 && existing.getLetter() > newLetter) {
                insertAt = i;
            }
        }
        if (insertAt < 0) {
            insertAt = this.letterListModel.size();
        }

        this.letterListModel.add(insertAt, sampleData);
        this.letters.setSelectedIndex(insertAt);
        this.entry.clear();
        this.sample.repaint();
    }

    /**
     * Called to clear the image.
     *
     * @param event
     *            The event (may be null when called internally)
     */
    void clear_actionPerformed(final java.awt.event.ActionEvent event) {
        this.entry.clear();
        this.sample.getData().clear();
        this.sample.repaint();
    }

    /**
     * Called when the del button is pressed. Removes the selected letter.
     *
     * @param event
     *            The event.
     */
    void del_actionPerformed(final java.awt.event.ActionEvent event) {
        final int i = this.letters.getSelectedIndex();
        if (i == -1) {
            JOptionPane.showMessageDialog(this,
                    "Please select a letter to delete.", "Error",
                    JOptionPane.ERROR_MESSAGE);
            return;
        }
        this.letterListModel.remove(i);
    }

    /**
     * Called to downsample the image.
     *
     * @param event
     *            The event
     */
    void downSample_actionPerformed(final java.awt.event.ActionEvent event) {
        this.entry.downSample();
    }

    /**
     * Called when a letter is selected from the list box. Shows a copy of
     * the selected sample and clears the drawing area.
     *
     * @param event
     *            The event
     */
    void letters_valueChanged(final ListSelectionEvent event) {
        if (this.letters.getSelectedIndex() == -1) {
            return;
        }
        final SampleData selected = (SampleData) this.letterListModel
                .getElementAt(this.letters.getSelectedIndex());
        this.sample.setData((SampleData) selected.clone());
        this.sample.repaint();
        this.entry.clear();
    }

    /**
     * Parses one line of {@code sample.dat}. A line is the letter, a
     * separator character, and then width*height characters of '1'/'0',
     * row by row.
     *
     * @param line
     *            A non-empty line from the file.
     * @param lineNumber
     *            1-based line number, used in the error message.
     * @return The parsed sample.
     * @throws java.io.IOException
     *             If the line is too short to hold a full grid.
     */
    private static SampleData parseSampleLine(final String line,
            final int lineNumber) throws java.io.IOException {
        final SampleData ds = new SampleData(line.charAt(0),
                OCR.DOWNSAMPLE_WIDTH, OCR.DOWNSAMPLE_HEIGHT);
        final int expected = 2 + ds.getWidth() * ds.getHeight();
        if (line.length() < expected) {
            throw new java.io.IOException("Line " + lineNumber
                    + " of 'sample.dat' is too short: expected " + expected
                    + " characters but found " + line.length() + ".");
        }
        // Skip the letter and the ':' separator written by save.
        int idx = 2;
        for (int y = 0; y < ds.getHeight(); y++) {
            for (int x = 0; x < ds.getWidth(); x++) {
                ds.setData(x, y, line.charAt(idx++) == '1');
            }
        }
        return ds;
    }

    /**
     * Called when the load button is pressed. Replaces the letter list with
     * the contents of {@code ./sample.dat}. The whole file is parsed first,
     * so a malformed file leaves the current list unchanged. Blank lines are
     * ignored.
     *
     * @param event
     *            The event
     */
    @SuppressWarnings("unchecked")
    void load_actionPerformed(final java.awt.event.ActionEvent event) {
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(new File(
                    "./sample.dat")));
            final java.util.List<SampleData> loaded = new java.util.ArrayList<SampleData>();
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().length() == 0) {
                    continue;
                }
                loaded.add(parseSampleLine(line, lineNumber));
            }
            reader.close();
            reader = null;

            this.letterListModel.clear();
            for (int i = 0; i < loaded.size(); i++) {
                this.letterListModel.add(i, loaded.get(i));
            }
            clear_actionPerformed(null);
            JOptionPane.showMessageDialog(this, "Loaded from 'sample.dat'.",
                    "Training", JOptionPane.PLAIN_MESSAGE);
        } catch (final Exception e) {
            e.printStackTrace();
            JOptionPane.showMessageDialog(this, "Error: " + e, "Training",
                    JOptionPane.ERROR_MESSAGE);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (final java.io.IOException ignored) {
                    // The primary failure has already been reported above.
                }
            }
        }
    }

    /**
     * Used to map neurons to actual letters.
     *
     * @return The current mapping between neurons and letters as an array;
     *         unmapped entries are '?'.
     */
    char[] mapNeurons() {
        return mapNeurons(this.net);
    }

    /**
     * Maps each output neuron of {@code network} to the known letter that
     * makes it win. The array has one slot per known letter. Neurons that
     * fall outside it are ignored, which can happen when letters were
     * deleted after training.
     *
     * @param network
     *            The network to query; may be null (everything maps to '?').
     * @return The neuron-to-letter mapping.
     */
    private char[] mapNeurons(final BasicNetwork network) {
        final char[] map = new char[this.letterListModel.size()];
        java.util.Arrays.fill(map, '?');
        if (network == null) {
            return map;
        }
        for (int i = 0; i < this.letterListModel.size(); i++) {
            final SampleData ds = (SampleData) this.letterListModel
                    .getElementAt(i);
            final int best = network.winner(toNeuralData(ds));
            if (best >= 0 && best < map.length) {
                map[best] = ds.getLetter();
            }
        }
        return map;
    }

    /**
     * Resets the training UI state and tells the user training has finished.
     * Must be called on the EDT.
     */
    public void markStopped() {
        this.trainThread = null;
        this.train.setText("Begin Training");
        JOptionPane.showMessageDialog(this, "Training has completed.",
                "Training", JOptionPane.PLAIN_MESSAGE);
    }

    /**
     * Called when the recognize button is pressed. Downsamples the drawing,
     * finds the winning neuron and shows the letter mapped to it.
     *
     * @param event
     *            The event.
     */
    void recognize_actionPerformed(final java.awt.event.ActionEvent event) {
        // Read the shared field once so both lookups use the same network.
        final BasicNetwork network = this.net;
        if (network == null) {
            JOptionPane.showMessageDialog(this, "I need to be trained first!",
                    "Error", JOptionPane.ERROR_MESSAGE);
            return;
        }
        this.entry.downSample();
        final int best = network.winner(toNeuralData(this.sample.getData()));
        final char[] map = mapNeurons(network);
        // The winner can exceed the map if letters were deleted after training.
        final char letter = (best >= 0 && best < map.length) ? map[best] : '?';
        JOptionPane.showMessageDialog(this, " " + letter + " (Neuron #"
                + best + " fired)", "That Letter Is",
                JOptionPane.PLAIN_MESSAGE);
        clear_actionPerformed(null);
    }

    /**
     * Run method for the background training thread. It builds one output
     * neuron per known letter and trains until {@link #halt} is set. It then
     * hands completion (or failure) back to the EDT.
     */
    public void run() {
        try {
            final int inputNeuron = OCR.DOWNSAMPLE_HEIGHT
                    * OCR.DOWNSAMPLE_WIDTH;
            final int outputNeuron = this.letterListModel.size();

            final NeuralDataSet trainingSet = new BasicNeuralDataSet();
            for (int t = 0; t < outputNeuron; t++) {
                final SampleData ds = (SampleData) this.letterListModel
                        .getElementAt(t);
                // Unsupervised training: no ideal output.
                trainingSet.add(new BasicNeuralDataPair(toNeuralData(ds), null));
            }

            final BasicNetwork network = new BasicNetwork();
            network.addLayer(new BasicLayer(new ActivationLinear(), false,
                    inputNeuron));
            network.addLayer(new BasicLayer(new ActivationLinear(), false,
                    outputNeuron));
            network.getStructure().finalizeStructure();
            network.reset();
            // Publish only a fully constructed network to the EDT.
            this.net = network;

            final CompetitiveTraining trainer = new CompetitiveTraining(
                    network, 0.25, trainingSet, new NeighborhoodSingleRBF(
                            new GaussianFunction(0, 1, 2)));
            trainer.setForceWinner(true);

            int iteration = 1;
            while (!this.halt) {
                trainer.iteration();
                update(iteration++, trainer.getError());
            }

            SwingUtilities.invokeLater(new Runnable() {
                public void run() {
                    OCR.this.markStopped();
                }
            });
        } catch (final Exception e) {
            e.printStackTrace();
            // Reset the training state and report on the EDT so the UI does
            // not stay stuck in "Stop Training" mode.
            SwingUtilities.invokeLater(new Runnable() {
                public void run() {
                    OCR.this.trainThread = null;
                    OCR.this.train.setText("Begin Training");
                    JOptionPane.showMessageDialog(OCR.this, "Error: " + e,
                            "Training", JOptionPane.ERROR_MESSAGE);
                }
            });
        }
    }

    /**
     * Called when the save button is clicked. Writes every letter to
     * {@code ./sample.dat}, one line per letter: the letter, ':', then the
     * grid as '1'/'0' characters row by row.
     *
     * @param event
     *            The event
     */
    void save_actionPerformed(final java.awt.event.ActionEvent event) {
        PrintStream ps = null;
        try {
            final OutputStream os = new FileOutputStream("./sample.dat", false);
            // PrintStream closes os when it is closed.
            ps = new PrintStream(os);
            for (int i = 0; i < this.letterListModel.size(); i++) {
                final SampleData ds = (SampleData) this.letterListModel
                        .elementAt(i);
                ps.print(ds.getLetter() + ":");
                for (int y = 0; y < ds.getHeight(); y++) {
                    for (int x = 0; x < ds.getWidth(); x++) {
                        ps.print(ds.getData(x, y) ? "1" : "0");
                    }
                }
                ps.println("");
            }
            // PrintStream swallows IOExceptions; checkError flushes and
            // reports whether any write failed.
            final boolean failed = ps.checkError();
            ps.close();
            ps = null;
            if (failed) {
                throw new java.io.IOException("Error writing 'sample.dat'.");
            }
            clear_actionPerformed(null);
            JOptionPane.showMessageDialog(this, "Saved to 'sample.dat'.",
                    "Training", JOptionPane.PLAIN_MESSAGE);
        } catch (final Exception e) {
            e.printStackTrace();
            JOptionPane.showMessageDialog(this, "Error: " + e, "Training",
                    JOptionPane.ERROR_MESSAGE);
        } finally {
            if (ps != null) {
                ps.close();
            }
        }
    }

    /**
     * Called when the train button is pressed. Starts a background training
     * thread, or asks the running one to stop.
     *
     * @param event
     *            The event.
     */
    void train_actionPerformed(final java.awt.event.ActionEvent event) {
        if (this.trainThread == null) {
            if (this.letterListModel.size() == 0) {
                JOptionPane.showMessageDialog(this,
                        "Please add or load some letters before training.",
                        "Error", JOptionPane.ERROR_MESSAGE);
                return;
            }
            // Clear any stop request left over from the previous run.
            this.halt = false;
            this.train.setText("Stop Training");
            this.train.repaint();
            this.trainThread = new Thread(this);
            this.trainThread.start();
        } else {
            this.halt = true;
        }
    }

    /**
     * Called from the training thread to update the stats on the EDT. It
     * blocks until the labels have been updated.
     *
     * @param retry
     *            How many tries (training iterations) have run.
     * @param error
     *            The current error.
     */
    public void update(final int retry, final double error) {
        final UpdateStats stats = new UpdateStats();
        stats.tries = retry;
        stats.error = error;
        try {
            SwingUtilities.invokeAndWait(stats);
        } catch (final InterruptedException e) {
            // Preserve the interrupt and ask the training loop to stop.
            Thread.currentThread().interrupt();
            this.halt = true;
        } catch (final Exception e) {
            // This runs on the training thread, so no Swing dialog here.
            e.printStackTrace();
        }
    }
}