package cx.prutser.sudoku.ocr;
import javax.imageio.ImageIO;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* @author Erik van Zijst
*/
public class Trainer {
private static final int WIDTH = 16;
private static final int HEIGHT = 16;
private String dir = ".";
private String filename = "config.net";
private long evals = 0L;
private long start = System.currentTimeMillis();
private SudokuDigitRecognizer engine;
static class TestValue {
private final double[] input;
private final int expectedDigit;
private final File file;
private long successCount = 0L;
public TestValue(int expectedDigit, byte[] pixels, File file) {
if (expectedDigit < 0 || expectedDigit > 9 || pixels == null || pixels.length != WIDTH * HEIGHT) {
throw new IllegalArgumentException("Pixel data out of range.");
} else {
this.expectedDigit = expectedDigit;
this.file = file;
this.input = OCRUtils.pixelsToPattern(pixels);
}
}
public int getExpectedDigit() {
return expectedDigit;
}
public double[] getInput() {
return input;
}
public File getFile() {
return file;
}
}
public Trainer(String... args) {
parseArgs(args);
File f = new File(filename);
if (f.exists()) {
InputStream in = null;
try {
in = new FileInputStream(f);
engine = new SudokuDigitRecognizer(in);
} catch (IOException e) {
System.err.println(String.format(
"Error reading initial network configuration (%s): %s",
f.getAbsolutePath(), e.getMessage()));
System.exit(1);
} finally {
try {
in.close();
} catch(IOException e) {}
}
} else {
engine = new SudokuDigitRecognizer();
}
}
private void train() {
final File baseDir = new File(dir);
if (!baseDir.exists() || !baseDir.isDirectory()) {
System.err.println(dir + " is not a directory.");
} else {
final List<TestValue> testValues = new ArrayList<TestValue>();
final String[] dirs = baseDir.list(new FilenameFilter() {
public boolean accept(File dir, String name) {
return new File(dir, name).isDirectory() &&
name.length() == 1 &&
"0123456789".contains(name);
}
});
for (String dir : dirs) {
final File[] files = new File(baseDir, dir).listFiles(new FilenameFilter() {
public boolean accept(File dir, String name) {
return name.endsWith(".png");
}
});
for (File file : files) {
try {
testValues.add(new TestValue(Integer.parseInt(dir), OCRUtils.getPixels(ImageIO.read(file)), file));
} catch(IllegalArgumentException iae) {
System.err.println("Error processing: " + file.getPath() + ": " + iae.getMessage());
} catch(IOException ioe) {
System.err.println("Error reading " + file.getPath() + ": " + ioe.getMessage());
}
}
System.out.println(String.format("Loaded %d images of \"%s\"", files.length, dir));
}
if (testValues.isEmpty()) {
System.out.println("No training images found.");
} else {
System.out.println(String.format(
"Training with learning rate %.3f and momentum %.3f.",
engine.getLearningRate(), engine.getMomentum()));
boolean exit = false;
final BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
while (!exit) {
try {
if (stdin.ready()) {
stdin.readLine();
while (true) {
System.out.print(
"Interrupt program and save current state?\n" +
"Y (save and stop), K (stop), S (save and resume), enter (resume) : ");
String str = stdin.readLine();
if ("Y".equalsIgnoreCase(str)) {
doSave();
exit = true;
break;
} else if ("K".equalsIgnoreCase(str)) {
exit = true;
break;
} else if ("S".equalsIgnoreCase(str)) {
doSave();
break;
} else if (str.isEmpty()) {
break;
}
}
} else {
final int success = doRun(testValues);
if (success == testValues.size()) {
doLog(testValues, success);
doSave();
exit = true;
} else {
long now = System.currentTimeMillis();
if (now - start > 1000L) {
doLog(testValues, success);
start = now;
}
}
}
} catch(IOException e) {
}
}
}
}
}
private int doRun(List<TestValue> testValues) {
int success = 0;
for (TestValue testValue : testValues) {
if (engine.trainAndClassifyResult(testValue.getExpectedDigit(), testValue.getInput())) {
success++;
testValue.successCount++;
}
evals++;
}
return success;
}
private void doSave() {
final File f = new File(filename);
OutputStream out = null;
try {
engine.save(out = new FileOutputStream(f));
System.out.println("Network configuration saved to " + f.getAbsolutePath());
} catch (IOException e) {
System.err.println("Unable to save the network configuration to " + f.getAbsolutePath());
} finally {
try {
out.close();
} catch (IOException e) {}
}
}
private void doLog(List<TestValue> testValues, int success) {
System.out.println(String.format("Recognized %d of %d images (%.2f%%, %d evals). Hardest image: %s",
success, testValues.size(), (success / (float)testValues.size()) * 100,
evals, getHardestImages(testValues, 1).get(0).getPath()));
}
private List<File> getHardestImages(List<TestValue> testValues, int count) {
List<TestValue> copy = new ArrayList<TestValue>(testValues);
Collections.sort(copy, new Comparator<TestValue>() {
public int compare(TestValue o1, TestValue o2) {
return (int)(o1.successCount - o2.successCount);
}
});
copy = copy.subList(0, Math.min(count, copy.size()));
final List<File> files = new ArrayList<File>();
for (TestValue testValue : copy) {
files.add(testValue.getFile());
}
return files;
}
public static void main(String... args) {
new Trainer(args).train();
}
private void parseArgs(String... args) {
final String usage = "Usage: java " + getClass().getName() + " [OPTIONS]\n" +
"\n" +
"Trains the neural network to recognize the solver digits 1-9 and the blank tile\n" +
"using the tile images from a specified directory. The tile images must be 8-bit\n" +
"gray scale in 16x16 resolution and png format.\n" +
"\n" +
"OPTIONS\n" +
" -d, --dir directory containing the tile images (defaults to .)\n" +
" -f, --file save learned network state to file (defaults to config.net)\n" +
" if this file exists at startup, it is used at initialization.\n" +
" -h, --help print this help message and exit.";
boolean exit = false;
try {
for (int i = 0; !exit && i < args.length; i++) {
if ("-d".equals(args[i]) || "--dir".equals(args[i])) {
dir = args[++i];
} else if("-f".equals(args[i]) || "--file".equals(args[i])) {
filename = args[++i];
} else if("-h".equals(args[i]) || "--help".equals(args[i])) {
exit = true;
}
}
} catch(ArrayIndexOutOfBoundsException e) {
exit = true;
}
if (exit) {
System.err.println(usage);
System.exit(1);
}
}
}