/*
* Encog(tm) Examples v2.4
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
*
* Copyright 2008-2010 by Heaton Research Inc.
*
* Released under the LGPL.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*
* Encog and Heaton Research are Trademarks of Heaton Research, Inc.
* For information on Heaton Research trademarks, visit:
*
* http://www.heatonresearch.com/copyright.html
*/
package org.encog.examples.neural.gui.mpg;
import java.awt.BorderLayout;
import java.awt.Container;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.util.List;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JTextField;
import javax.swing.SwingUtilities;
import org.encog.ConsoleStatusReportable;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.normalize.DataNormalization;
import org.encog.normalize.input.BasicInputField;
import org.encog.normalize.input.InputField;
import org.encog.normalize.input.InputFieldCSV;
import org.encog.normalize.output.OutputField;
import org.encog.normalize.output.OutputFieldRangeMapped;
import org.encog.normalize.target.NormalizationStorageNeuralDataSet;
import org.encog.persist.EncogPersistedCollection;
import org.encog.persist.EncogPersistedObject;
import org.encog.util.simple.EncogUtility;
/**
 * Swing front end for the Encog "miles per gallon" example.
 * <p>
 * Every action derives its files from the path typed into the "Data file"
 * field: the CSV file itself, and {@code mpg.eg} in the same directory. The
 * latter holds the normalized training set ("data"), the network ("network")
 * and the normalization settings ("norm").
 * <ul>
 * <li><b>Generate</b> - normalizes CSV columns 0-5 (column 0 is marked as the
 * ideal output, columns 1-5 are the inputs, all range-mapped with bounds
 * -0.5/0.5). It stores the result, a network created by
 * {@code EncogUtility.simpleFeedForward(5, 7, 0, 1, true)} and the
 * normalization object in {@code mpg.eg}.</li>
 * <li><b>Train</b> - on a background thread, loads the network and training
 * set from {@code mpg.eg}, trains via {@code EncogUtility.trainDialog} and
 * adds the network back to the collection.</li>
 * <li><b>Evaluate</b> - passes the vehicle attributes from the form to a
 * {@link CalculateMPG} built from {@code mpg.eg} and shows the result.</li>
 * </ul>
 * The {@code training} and {@code calc} fields are shared between the event
 * dispatch thread and the training thread, so they are {@code volatile}.
 */
public class MilesPerGallon extends JFrame implements ActionListener, Runnable {

    private static final long serialVersionUID = 1L;

    private JButton buttonGenerate;
    private JButton buttonTrain;
    private JButton buttonEvaluate;
    private JTextField textPath;
    private JTextField textHP;
    private JTextField textDisp;
    private JTextField textWeight;
    private JTextField textCyl;
    private JTextField textAccel;
    /** CSV file named in the "Data file" field; set by obtainPaths(). */
    private File fileCSV;
    /** Absolute directory containing fileCSV; set by obtainPaths(). */
    private File directory;
    /** The "mpg.eg" file inside directory; set by obtainPaths(). */
    private File encogFile;
    /** True while the training thread runs; reset by that thread when it ends. */
    private volatile boolean training;
    /** Lazily created calculator; cleared whenever mpg.eg is rewritten so it gets reloaded. */
    private volatile CalculateMPG calc;
    JLabel labelMPG;

    /**
     * Builds the window. The attribute fields, the data file path and the MPG
     * result label go in a 7x2 grid, with the three action buttons below it.
     */
    public MilesPerGallon() {
        setTitle("Neural Network Miles Per Gallon");
        this.setSize(320, 200);
        Container content = this.getContentPane();

        JPanel buttonPanel = new JPanel();
        buttonPanel.add(buttonGenerate = new JButton("Generate"));
        buttonPanel.add(buttonTrain = new JButton("Train"));
        buttonPanel.add(buttonEvaluate = new JButton("Evaluate"));
        content.setLayout(new BorderLayout());
        content.add(buttonPanel, BorderLayout.SOUTH);

        JPanel gridPanel = new JPanel();
        gridPanel.setLayout(new GridLayout(7, 2));
        content.add(gridPanel, BorderLayout.CENTER);
        gridPanel.add(new JLabel("Cylinders"));
        gridPanel.add(this.textCyl = new JTextField("6"));
        gridPanel.add(new JLabel("Displacement (cu.inch)"));
        gridPanel.add(this.textDisp = new JTextField("183"));
        gridPanel.add(new JLabel("Horse Power"));
        gridPanel.add(this.textHP = new JTextField("230"));
        gridPanel.add(new JLabel("Weight (lbs)"));
        gridPanel.add(this.textWeight = new JTextField("2300"));
        gridPanel.add(new JLabel("Acceleration (0-60 mph)"));
        gridPanel.add(this.textAccel = new JTextField("6"));
        gridPanel.add(new JLabel("Data file"));
        gridPanel.add(this.textPath = new JTextField("c:\\mpg\\mpg.csv"));
        gridPanel.add(new JLabel("Vehicle MPG"));
        gridPanel.add(labelMPG = new JLabel("Click to Calc"));

        buttonTrain.addActionListener(this);
        buttonGenerate.addActionListener(this);
        buttonEvaluate.addActionListener(this);
        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    }

    /**
     * Entry point: creates and shows the window.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Swing components must be created and shown on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                MilesPerGallon frame = new MilesPerGallon();
                frame.setVisible(true);
            }
        });
    }

    /**
     * Parses the text of a numeric input field.
     *
     * @param field the field to read
     * @param name  human-readable field name used in the error message
     * @return the parsed value
     * @throws IllegalArgumentException if the text is not a valid double; the
     *         original NumberFormatException is kept as the cause
     */
    private double parseField(JTextField field, String name) {
        try {
            return Double.parseDouble(field.getText());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Please enter valid value: " + name, e);
        }
    }

    /** Dispatches button clicks to the matching action. */
    public void actionPerformed(ActionEvent event) {
        Object source = event.getSource();
        if (source == this.buttonEvaluate) {
            performEvaluate();
        } else if (source == this.buttonTrain) {
            performTrain();
        } else if (source == this.buttonGenerate) {
            performBuild();
        }
    }

    /** Shows a modal message box; must be called on the event dispatch thread. */
    private void message(String text) {
        JOptionPane.showMessageDialog(null, text);
    }

    /** Schedules {@link #message(String)} on the event dispatch thread; safe from any thread. */
    private void messageLater(final String text) {
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                message(text);
            }
        });
    }

    /** Returns the exception message, or its toString() when the message is null. */
    private static String describe(Throwable e) {
        String text = e.getMessage();
        return text != null ? text : e.toString();
    }

    /**
     * Resolves fileCSV, directory and encogFile from the "Data file" field.
     * Must run on the event dispatch thread because it reads a text field and
     * may show a dialog.
     *
     * @return true if the containing directory exists; otherwise a message
     *         has already been shown and the caller should abort
     */
    private boolean obtainPaths() {
        this.fileCSV = new File(this.textPath.getText());
        // getAbsoluteFile() gives relative names such as "mpg.csv" a parent;
        // getParentFile() on the raw name would return null for them.
        this.directory = this.fileCSV.getAbsoluteFile().getParentFile();
        if (this.directory == null || !this.directory.exists()) {
            message("Can't load: " + this.directory);
            return false;
        }
        this.encogFile = new File(this.directory, "mpg.eg");
        return true;
    }

    /**
     * Checks that a required file exists and shows a message if it does not.
     *
     * @return true if {@code file} exists
     */
    private boolean requireFile(File file) {
        if (!file.exists()) {
            message("Can't load: " + file);
            return false;
        }
        return true;
    }

    /**
     * Normalizes the CSV file. It stores the normalized training set, a new
     * untrained network and the normalization settings in mpg.eg.
     */
    private void performBuild() {
        if (this.training) {
            // The training thread writes the network back into mpg.eg when it
            // finishes; regenerating the file now would race with that write.
            message("Already training");
            return;
        }
        if (!obtainPaths() || !requireFile(this.fileCSV)) {
            return;
        }

        final double lo = -0.5;
        final double hi = 0.5;
        try {
            DataNormalization norm = new DataNormalization();
            // Column 0 (MPG) is passed false; columns 1-5 are passed true.
            InputField inputMPG = new InputFieldCSV(false, this.fileCSV, 0);
            norm.addInputField(inputMPG);
            InputField inputCylinders = new InputFieldCSV(true, this.fileCSV, 1);
            norm.addInputField(inputCylinders);
            InputField inputDisplacement = new InputFieldCSV(true, this.fileCSV, 2);
            norm.addInputField(inputDisplacement);
            InputField inputHorsePower = new InputFieldCSV(true, this.fileCSV, 3);
            norm.addInputField(inputHorsePower);
            InputField inputWeight = new InputFieldCSV(true, this.fileCSV, 4);
            norm.addInputField(inputWeight);
            InputField inputAcceleration = new InputFieldCSV(true, this.fileCSV, 5);
            norm.addInputField(inputAcceleration);

            norm.addOutputField(new OutputFieldRangeMapped(inputCylinders, lo, hi));
            norm.addOutputField(new OutputFieldRangeMapped(inputDisplacement, lo, hi));
            norm.addOutputField(new OutputFieldRangeMapped(inputHorsePower, lo, hi));
            norm.addOutputField(new OutputFieldRangeMapped(inputWeight, lo, hi));
            norm.addOutputField(new OutputFieldRangeMapped(inputAcceleration, lo, hi));
            OutputField mpg = new OutputFieldRangeMapped(inputMPG, lo, hi);
            norm.addOutputField(mpg);
            // MPG is the value the network learns to predict.
            mpg.setIdeal(true);

            // 5 input columns, 1 ideal column.
            NormalizationStorageNeuralDataSet target = new NormalizationStorageNeuralDataSet(5, 1);
            norm.setReport(new ConsoleStatusReportable());
            norm.setTarget(target);
            norm.process();

            EncogPersistedCollection encog = new EncogPersistedCollection(this.encogFile);
            encog.add("data", (EncogPersistedObject) target.getDataset());
            BasicNetwork network = EncogUtility.simpleFeedForward(5, 7, 0, 1, true);
            encog.add("network", network);
            encog.add("norm", norm);
        } catch (RuntimeException e) {
            message("Unable to process " + this.fileCSV + ": " + describe(e));
            return;
        } finally {
            // mpg.eg has been (re)written; force the next evaluation to reload it.
            this.calc = null;
        }
        message("Success. Done processing CSV file.");
    }

    /**
     * Starts training on a background thread unless training is already
     * running. The paths are resolved here, on the event dispatch thread,
     * before the worker starts.
     */
    private void performTrain() {
        if (this.training) {
            message("Already training");
            return;
        }
        if (!obtainPaths() || !requireFile(this.encogFile)) {
            return;
        }
        this.training = true;
        // Thread.start() makes the encogFile assignment above visible to the worker.
        new Thread(this, "MPG training").start();
    }

    /**
     * Evaluates the network on the attributes entered in the form and shows
     * the result in the MPG label. The calculator is loaded from mpg.eg on
     * first use and cached until the file is rewritten.
     */
    private void performEvaluate() {
        if (!obtainPaths() || !requireFile(this.encogFile)) {
            return;
        }
        try {
            // Work on a local copy: the training thread may null the field at any time.
            CalculateMPG calculator = this.calc;
            if (calculator == null) {
                calculator = new CalculateMPG(this.encogFile);
                this.calc = calculator;
            }
            double cylinders = this.parseField(this.textCyl, "cylinders");
            double displacement = this.parseField(this.textDisp, "displacement");
            double horsePower = this.parseField(this.textHP, "horsePower");
            double weight = this.parseField(this.textWeight, "weight");
            double acceleration = this.parseField(this.textAccel, "acceleration");
            double mpg = calculator.calulate(
                    cylinders,
                    displacement,
                    horsePower,
                    weight,
                    acceleration);
            this.labelMPG.setText(String.valueOf(mpg));
        } catch (Exception e) {
            message(describe(e));
        }
    }

    /**
     * Training thread body. It loads the network and training set from
     * encogFile (resolved by performTrain() before the thread started), trains
     * the network and adds it back to the collection. Errors are reported on
     * the event dispatch thread. The training flag is always cleared, and the
     * cached calculator is dropped so the next evaluation reloads the network.
     */
    public void run() {
        final File file = this.encogFile;
        try {
            EncogPersistedCollection encog = new EncogPersistedCollection(file);
            BasicNetwork network = (BasicNetwork) encog.find("network");
            NeuralDataSet trainingSet = (NeuralDataSet) encog.find("data");
            if (network == null || trainingSet == null) {
                messageLater("Can't load network/data from: " + file
                        + " (click Generate first)");
                return;
            }
            EncogUtility.trainDialog(network, trainingSet);
            encog.add("network", network);
        } catch (RuntimeException e) {
            messageLater("Training failed: " + describe(e));
        } finally {
            this.calc = null;
            this.training = false;
        }
    }
}