import java.io.*; import java.util.ArrayList; public class MachineLearner { public Memory memory; public MachineLearner() { memory = new Memory(); } public void learn(File imageFile, File labelFile) { ArrayList<MatchEntity> matchEntities = Parser.getMatchingEntities(imageFile, labelFile); //printMatchEntities(matchEntities); //save(matchEntities); for (MatchEntity matchEntity : matchEntities) { int[][] results = processImage(matchEntity); memory.upgrade(results, matchEntity.getValue()); } } private int[][] processImage(MatchEntity matchEntity) { return new ImageProcessor().process(matchEntity.getImage()); } private static void save(ArrayList<MatchEntity> matchEntities) { PrintWriter writer = null; try { writer = new PrintWriter("toDelete.txt", "UTF-8"); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (UnsupportedEncodingException e) { e.printStackTrace(); } for (int i = 0; i < matchEntities.size(); i++) { writer.write(matchEntities.get(i).toString() + "\n"); } writer.close(); } private static void printMatchEntities(ArrayList<MatchEntity> matchEntities) { for (int i = 0; i < matchEntities.size(); i++) { matchEntities.get(i).print(); } } public ArrayList<Integer> recognize(File image, File label) { ArrayList<MatchEntity> matchEntities = Parser.getMatchingEntities(image, label); ArrayList<Integer> predictions = new ArrayList(); for (MatchEntity matchEntity : matchEntities) { double[] probabilities= getProbabilities(matchEntity); predictions.add(getIndexOfHighestProbability(probabilities)); } return predictions; } private int getIndexOfHighestProbability(double[] probabilities) { double max=1; int index=-1; for(int i=0;i<10;i++) { if(probabilities[i]>max) { max=probabilities[i]; index=i; } } return index; } private double[] getProbabilities(MatchEntity matchEntity) { double[] probability = new double[10]; int[][] results = processImage(matchEntity); for(int i =0; i< 10; i++) { probability[i] = bayes(results,i); } return probability; } private double bayes(int[][] results, int i) { double probability=1; for (int x = 0; x < results.length; x++) { for (int y = 0; y < results[x].length; y++) { if(results[y][x]==Parameters.BLACK) { probability*=memory.getData(i)[y][x].getRatio()/memory.getCommonData()[y][x].getRatio(); } else{ probability*=memory.getData(i)[y][x].getOppositeRatio()/memory.getCommonData()[y][x].getOppositeRatio(); } } } return probability; } }