import org.junit.Test;

import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;

import static org.fest.assertions.api.Assertions.assertThat;

/**
 * Tests that exercise {@link MachineLearner} against the {@code digitdata/*} resources found on
 * the test classpath: learning from the debug/training/validation sets and recognizing the test set.
 */
public class Launcher {

    /**
     * Resolves a classpath resource to a {@link File}.
     *
     * <p>Uses {@link URL#toURI()} instead of {@link URL#getFile()}: the latter returns the
     * percent-encoded path (e.g. {@code %20} for a space), which does not name an existing file
     * when the project lives in a directory whose path needs escaping.
     *
     * @param name classpath-relative resource name, e.g. {@code "digitdata/testimages"}
     * @return the resource as a file
     * @throws IllegalStateException if the resource is not on the classpath or its URL is not a
     *     valid URI
     */
    private File resource(String name) {
        URL url = getClass().getClassLoader().getResource(name);
        if (url == null) {
            throw new IllegalStateException("Test resource not found on classpath: " + name);
        }
        try {
            return new File(url.toURI());
        } catch (URISyntaxException e) {
            throw new IllegalStateException("Invalid URI for test resource: " + url, e);
        }
    }

    /**
     * Learns from the small debug data set and checks specific learned ratios.
     *
     * <p>Every cell of digits 3..9 is expected to keep a ratio of -1 (presumably because those
     * digits do not occur in the debug set — NOTE(review): confirm against the debug labels).
     */
    @Test
    public void debug() {
        MachineLearner ml = new MachineLearner();
        ml.learn(resource("digitdata/debugimages"), resource("digitdata/debuglabels"));

        for (int digit = 3; digit <= 9; digit++) {
            for (int y = 0; y < Parameters.NUMBER_OF_ROWS; y++) {
                for (int x = 0; x < Parameters.NUMBER_OF_COLUMNS; x++) {
                    // Index by x: previously column 0 was checked on every iteration.
                    assertThat(ml.memory.getData(digit)[y][x].getRatio()).isEqualTo(-1);
                }
            }
        }

        assertThat(ml.memory.getData(0)[0][0].getRatio()).isEqualTo(1);
        assertThat(ml.memory.getData(0)[0][3].getRatio()).isEqualTo(0);
        assertThat(ml.memory.getData(1)[0][0].getRatio()).isEqualTo(0);
        assertThat(ml.memory.getCommonData()[0][0].getRatio()).isEqualTo(0.5);
        assertThat(ml.memory.getCommonData()[0][3].getRatio()).isEqualTo(0.5);
    }

    /**
     * Smoke test: learning from the training set and then the validation set completes without
     * throwing. It makes no assertions about what was learned.
     */
    @Test
    public void it_learns_from_images() {
        MachineLearner ml = new MachineLearner();
        ml.learn(resource("digitdata/trainingimages"), resource("digitdata/traininglabels"));
        ml.learn(resource("digitdata/validationimages"), resource("digitdata/validationlabels"));
    }

    /**
     * Trains on the training and validation sets, recognizes the test images, and requires that
     * more than 75% of recognized values match the values returned by
     * {@code Parser.getMatchingEntities} for the same test files.
     */
    @Test
    public void it_recognizes_an_input() {
        File testImage = resource("digitdata/testimages");
        File testLabel = resource("digitdata/testlabels");

        MachineLearner ml = new MachineLearner();
        ml.learn(resource("digitdata/trainingimages"), resource("digitdata/traininglabels"));
        ml.learn(resource("digitdata/validationimages"), resource("digitdata/validationlabels"));

        ArrayList<Integer> recognized = ml.recognize(testImage, testLabel);
        ArrayList<MatchEntity> expected = Parser.getMatchingEntities(testImage, testLabel);

        // An empty set would make the accuracy NaN, and a length mismatch would pair up the wrong
        // entries (or run out of bounds); fail with a clear message in either case.
        assertThat(expected).isNotEmpty();
        assertThat(recognized).hasSize(expected.size());

        int success = 0;
        for (int i = 0; i < expected.size(); i++) {
            // recognized.get(i) is a boxed Integer; compare numerically rather than with a
            // reference-sensitive ==.
            if (recognized.get(i).intValue() == expected.get(i).getValue()) {
                success++;
            }
        }

        double probability = success * 100.0 / expected.size();
        System.out.println(probability);
        assertThat(probability).isGreaterThan(75);
    }
}