package vafusion.recog;
import java.awt.Dimension;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.Set;
import java.util.Vector;
import javax.swing.JFrame;
import org.neuroph.contrib.imgrec.ColorMode;
import org.neuroph.contrib.imgrec.FractionRgbData;
import org.neuroph.contrib.ocr.OcrHelper;
import org.neuroph.contrib.ocr.OcrUtils;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.util.TransferFunctionType;
@SuppressWarnings("serial")
public class NetworkTrainer extends JFrame {
/**
* Creates a new network, trains it using the data provided in the args, and saves it to a file.
* @throws IOException
* @throws InterruptedException
*/
public static void main(String[] args) throws IOException, InterruptedException {
String filename = args[0];
String[] data = Arrays.copyOfRange(args, 1, args.length);
List<HashMap<String, BufferedImage>> imageMaps = new ArrayList<HashMap<String, BufferedImage>>();
HashMap<String, BufferedImage> currentImageMap = new HashMap<String, BufferedImage>();
imageMaps.add(currentImageMap);
List<List<String>> labelsList = new ArrayList<List<String>>();
List<String> currentLabels = new ArrayList<String>();
labelsList.add(currentLabels);
List<TrainingSet> trainingSets = new ArrayList<TrainingSet>();
for(int i = 0; i < data.length; i++) {
if(currentLabels.contains(data[i].substring(4, 5))) {
currentLabels = null;
currentImageMap = null;
for(int j = 0; j < labelsList.size(); j++)
if(!((currentLabels = labelsList.get(j)).contains(data[i].substring(4, 5)))) {
currentImageMap = imageMaps.get(j);
break;
}
if(currentImageMap == null) {
currentLabels = new ArrayList<String>();
labelsList.add(currentLabels);
currentImageMap = new HashMap<String, BufferedImage>();
imageMaps.add(currentImageMap);
}
}
currentLabels.add(data[i].substring(4, 5));
currentImageMap.put(data[i].substring(4, 5), OcrUtils.loadImage(new File(data[i])));
currentLabels = labelsList.get(0); //reset the two lists so you're always
//filling the initial lists before moving forward
currentImageMap = imageMaps.get(0);
System.out.println(labelsList.size());
}
//System.out.println(OcrUtils.getFractionRgbDataForImages(imageMaps.get(0)));
for(int i = 0; i < 10; i++)
System.out.println(OcrUtils.getFractionRgbDataForImages(imageMaps.get(i)).size());
FractionRgbData frd = OcrUtils.getFractionRgbDataForImages(imageMaps.get(0)).get(labelsList.get(0).get(0));
Dimension size = new Dimension(frd.getWidth(), frd.getHeight());
System.out.println(size.getWidth() * size.getHeight());
Vector<Integer> layerCounts = new Vector<Integer>();
layerCounts.add(100);
NeuralNetwork nnet = OcrHelper.createNewNeuralNetwork("CharacterRecognizer", size, ColorMode.BLACK_AND_WHITE,
Character.getCharacters(), layerCounts , TransferFunctionType.TANH);
// NeuralNetwork nnet = NeuralNetwork.load(filename);
System.out.println(nnet.getInputNeurons().size());
for(int i = 0; i < labelsList.size(); i++)
trainingSets.add(OcrHelper.createBlackAndWhiteTrainingSet(labelsList.get(i),
OcrUtils.getFractionRgbDataForImages(imageMaps.get(i))));
System.out.println(trainingSets.size());
nnet.initializeWeights(-100, 100);
final Temp b = new Temp();
Thread w = new Thread(new Runnable() {
@Override
public void run() {
Scanner in = new Scanner(System.in);
in.nextLine();
System.out.println("stopping after the next pass");
b.running = false;
}
});
w.start();
int i = 0;
while(b.running) {
System.out.println("Pass: " + (i + 1));
Set<Thread> threadList = new HashSet<Thread>();
System.out.println("Spawning training threads.");
for(TrainingSet ts : trainingSets) {
nnet.learnInNewThread(ts);
threadList.add(nnet.getLearningThread());
}
System.out.println("All threads spawned. Count: " + threadList.size());
boolean learning = true;
Set<Thread> finishedList = new HashSet<Thread>();
int oldCount = -1;
while(learning) {
learning = false;
Thread k;
if(!threadList.contains(k = nnet.getLearningThread())) //in case we missed any of them.
threadList.add(k);
for(Thread t : threadList) {
if(t.isAlive())
learning = true;
else if(!finishedList.contains(t))
finishedList.add(t);
}
Thread.sleep(10000);
if(finishedList.size() > oldCount)
System.out.println("Finished count: " + (oldCount = finishedList.size()));
}
i++;
}
nnet.save(filename);
}
static class Temp {
boolean running = true;
}
}