package cx.prutser.sudoku.ocr; import com.digiburo.backprop1.BackProp; import com.digiburo.backprop1.Pattern; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; /** * OCR engine for recognizing the digits [1-9] and the blank tile in a sudoku. * * @author Erik van Zijst */ public class SudokuDigitRecognizer { private static final double one = 0.9999999999D; private static final double zero = 0.0000000001D; private static final int WIDTH = 16; private static final int HEIGHT = 16; private static final int hiddenLayer = 128; private static final float learningRate = 0.25F; // was 0.25 private static final float momentum = 0.30F; // was 0.75 private static final double[][] outputPattern = new double[10][10]; private final BackProp backProp; static { for (int i = 0; i < outputPattern.length; i++) { Arrays.fill(outputPattern[i], zero); outputPattern[i][i] = one; } } /** * Creates a new, unconfigured digit recognizer that has to be trained first. */ public SudokuDigitRecognizer() { backProp = new BackProp(WIDTH * HEIGHT, hiddenLayer, 10, learningRate, momentum); } /** * Creates a new instance of this digit recognizer, initialized with the * network configuration stored in the specified file. * * @param in * @throws IOException when the network configuration could not be read. */ public SudokuDigitRecognizer(InputStream in) throws IOException { try { backProp = new BackProp(in); } catch(ClassNotFoundException e) { throw new IOException("The network configuration file could not parsed.", e); } } public float getLearningRate() { return learningRate; } public float getMomentum() { return momentum; } /** * * @param pixels * @return the successfully recognized digit ([0-9]), or -1 * when the image could not be recognized. */ public int testAndClassify(double[] pixels) { return classifyResult(test(pixels)); } public double[] test(double[] pixels) { if (pixels.length != WIDTH * HEIGHT) { throw new IllegalArgumentException("Unsupported tile size: " + pixels.length); } else { backProp.setInputPattern(pixels); backProp.runNetwork(); return backProp.getOutputPattern(); } } protected int classifyResult(double[] result) { if (result.length != 10) { throw new IllegalArgumentException("Invalid array length: " + result.length); } else { for (int i = 0; i < outputPattern.length; i++) { boolean recognized = true; for (int j = 0; j < result.length; j++) { if (round2(result[j]) != round2(outputPattern[i][j])) { recognized = false; break; } } if (recognized) { return i; } } return -1; } } public boolean trainAndClassifyResult(int expectedDigit, double[] pixels) throws IllegalArgumentException { double[] result = train(expectedDigit, pixels); for (int i = 0; i < result.length; i++) { if (round1(result[i]) != round2(outputPattern[expectedDigit][i])) { return false; } } return true; } /** * Trains the network on one tile image. * * @param pixels the 8-bit gray scale pixel data (must be between 0 and 1). * @param expectedDigit the expected outcome ([0-9]). * @throws IllegalArgumentException when the pixel data is not 16x16 wide, * or the expectedDigit is out of range. * @return <code>true</code> if the image was successfully recognized, * <code>false</code> if not. */ public double[] train(int expectedDigit, double[] pixels) throws IllegalArgumentException { if (expectedDigit < 0 || expectedDigit > 9 || pixels == null || pixels.length != WIDTH * HEIGHT) { throw new IllegalArgumentException("Input out of range."); } else { backProp.setInputPattern(pixels); backProp.runNetwork(); backProp.trainNetwork(new Pattern(pixels, outputPattern[expectedDigit])); return backProp.getOutputPattern(); } } /** * Map an answer from the network to a value suitable for truth comparison * * @param candidate value from network * @return value for comparison w/truth */ private int round1(double candidate) { if (candidate > 0.85D) { return 1; } else if (candidate < 0.15D) { return 0; } else { return -1; } } /** * Map a truth value to a value suitable for comparison * * @param candidate value from truth pattern * @return value for comparison w/truth */ private int round2(double candidate) { if (candidate >= 0.5D) { return 1; } return 0; } /** * Writes the current network configuration to the specified stream. * * @param out * @throws IOException */ public void save(OutputStream out) throws IOException { backProp.writer(out); } }