/*
* Encog(tm) Java Examples v3.4
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-examples
*
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.examples.neural.image;
import java.awt.Image;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import javax.imageio.ImageIO;
import org.encog.Encog;
import org.encog.EncogError;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.train.strategy.ResetStrategy;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.platformspecific.j2se.TrainingDialog;
import org.encog.platformspecific.j2se.data.image.ImageMLData;
import org.encog.platformspecific.j2se.data.image.ImageMLDataSet;
import org.encog.util.downsample.Downsample;
import org.encog.util.downsample.RGBDownsample;
import org.encog.util.downsample.SimpleIntensityDownsample;
import org.encog.util.simple.EncogUtility;
/**
* Should have an input file similar to:
*
* CreateTraining: width:16,height:16,type:RGB
* Input: image:./coins/dime.png, identity:dime
* Input: image:./coins/dollar.png, identity:dollar
* Input: image:./coins/half.png, identity:half dollar
* Input: image:./coins/nickle.png, identity:nickle
* Input: image:./coins/penny.png, identity:penny
* Input: image:./coins/quarter.png, identity:quarter
* Network: hidden1:100, hidden2:0
* Train: Mode:console, Minutes:1, StrategyError:0.25, StrategyCycles:50
* Whatis: image:./coins/dime.png
* Whatis: image:./coins/half.png
* Whatis: image:./coins/testcoin.png
*
*/
public class ImageNeuralNetwork {
class ImagePair {
private final File file;
private final int identity;
public ImagePair(final File file, final int identity) {
super();
this.file = file;
this.identity = identity;
}
public File getFile() {
return this.file;
}
public int getIdentity() {
return this.identity;
}
}
public static void main(final String[] args) {
if (args.length < 1) {
System.out
.println("Must specify command file. See source for format.");
} else {
try {
final ImageNeuralNetwork program = new ImageNeuralNetwork();
program.execute(args[0]);
} catch (final Exception e) {
e.printStackTrace();
}
}
Encog.getInstance().shutdown();
}
private final List<ImagePair> imageList = new ArrayList<ImagePair>();
private final Map<String, String> args = new HashMap<String, String>();
private final Map<String, Integer> identity2neuron = new HashMap<String, Integer>();
private final Map<Integer, String> neuron2identity = new HashMap<Integer, String>();
private ImageMLDataSet training;
private String line;
private int outputCount;
private int downsampleWidth;
private int downsampleHeight;
private BasicNetwork network;
private Downsample downsample;
private int assignIdentity(final String identity) {
if (this.identity2neuron.containsKey(identity.toLowerCase())) {
return this.identity2neuron.get(identity.toLowerCase());
}
final int result = this.outputCount;
this.identity2neuron.put(identity.toLowerCase(), result);
this.neuron2identity.put(result, identity.toLowerCase());
this.outputCount++;
return result;
}
public void execute(final String file) throws IOException {
final FileInputStream fstream = new FileInputStream(file);
final DataInputStream in = new DataInputStream(fstream);
final BufferedReader br = new BufferedReader(new InputStreamReader(in));
while ((this.line = br.readLine()) != null) {
executeLine();
}
in.close();
}
private void executeCommand(final String command,
final Map<String, String> args) throws IOException {
if (command.equals("input")) {
processInput();
} else if (command.equals("createtraining")) {
processCreateTraining();
} else if (command.equals("train")) {
processTrain();
} else if (command.equals("network")) {
processNetwork();
} else if (command.equals("whatis")) {
processWhatIs();
}
}
public void executeLine() throws IOException {
final int index = this.line.indexOf(':');
if (index == -1) {
throw new EncogError("Invalid command: " + this.line);
}
final String command = this.line.substring(0, index).toLowerCase()
.trim();
final String argsStr = this.line.substring(index + 1).trim();
final StringTokenizer tok = new StringTokenizer(argsStr, ",");
this.args.clear();
while (tok.hasMoreTokens()) {
final String arg = tok.nextToken();
final int index2 = arg.indexOf(':');
if (index2 == -1) {
throw new EncogError("Invalid command: " + this.line);
}
final String key = arg.substring(0, index2).toLowerCase().trim();
final String value = arg.substring(index2 + 1).trim();
this.args.put(key, value);
}
executeCommand(command, this.args);
}
private String getArg(final String name) {
final String result = this.args.get(name);
if (result == null) {
throw new EncogError("Missing argument " + name + " on line: "
+ this.line);
}
return result;
}
private void processCreateTraining() {
final String strWidth = getArg("width");
final String strHeight = getArg("height");
final String strType = getArg("type");
this.downsampleHeight = Integer.parseInt(strHeight);
this.downsampleWidth = Integer.parseInt(strWidth);
if (strType.equals("RGB")) {
this.downsample = new RGBDownsample();
} else {
this.downsample = new SimpleIntensityDownsample();
}
this.training = new ImageMLDataSet(this.downsample, false, 1, -1);
System.out.println("Training set created");
}
private void processInput() throws IOException {
final String image = getArg("image");
final String identity = getArg("identity");
final int idx = assignIdentity(identity);
final File file = new File(image);
this.imageList.add(new ImagePair(file, idx));
System.out.println("Added input image:" + image);
}
private void processNetwork() throws IOException {
System.out.println("Downsampling images...");
for (final ImagePair pair : this.imageList) {
final MLData ideal = new BasicMLData(this.outputCount);
final int idx = pair.getIdentity();
for (int i = 0; i < this.outputCount; i++) {
if (i == idx) {
ideal.setData(i, 1);
} else {
ideal.setData(i, -1);
}
}
final Image img = ImageIO.read(pair.getFile());
final ImageMLData data = new ImageMLData(img);
this.training.add(data, ideal);
}
final String strHidden1 = getArg("hidden1");
final String strHidden2 = getArg("hidden2");
this.training.downsample(this.downsampleHeight, this.downsampleWidth);
final int hidden1 = Integer.parseInt(strHidden1);
final int hidden2 = Integer.parseInt(strHidden2);
this.network = EncogUtility.simpleFeedForward(this.training
.getInputSize(), hidden1, hidden2,
this.training.getIdealSize(), true);
System.out.println("Created network: " + this.network.toString());
}
private void processTrain() throws IOException {
final String strMode = getArg("mode");
final String strMinutes = getArg("minutes");
final String strStrategyError = getArg("strategyerror");
final String strStrategyCycles = getArg("strategycycles");
System.out.println("Training Beginning... Output patterns="
+ this.outputCount);
final double strategyError = Double.parseDouble(strStrategyError);
final int strategyCycles = Integer.parseInt(strStrategyCycles);
final ResilientPropagation train = new ResilientPropagation(this.network, this.training);
train.addStrategy(new ResetStrategy(strategyError, strategyCycles));
if (strMode.equalsIgnoreCase("gui")) {
TrainingDialog.trainDialog(train, this.network, this.training);
} else {
final int minutes = Integer.parseInt(strMinutes);
EncogUtility.trainConsole(train, this.network, this.training,
minutes);
}
System.out.println("Training Stopped...");
}
public void processWhatIs() throws IOException {
final String filename = getArg("image");
final File file = new File(filename);
final Image img = ImageIO.read(file);
final ImageMLData input = new ImageMLData(img);
input.downsample(this.downsample, false, this.downsampleHeight,
this.downsampleWidth, 1, -1);
final int winner = this.network.winner(input);
System.out.println("What is: " + filename + ", it seems to be: "
+ this.neuron2identity.get(winner));
}
}