//----------------------------------------------------------------------------// // // // N e u r a l N e t w o r k T e s t // // // //----------------------------------------------------------------------------// // <editor-fold defaultstate="collapsed" desc="hdr"> // // Copyright (C) Hervé Bitteur 2000-2011. All rights reserved. // // This software is released under the GNU General Public License. // // Goto http://kenai.com/projects/audiveris to report bugs or suggestions. // //----------------------------------------------------------------------------// // </editor-fold> package omr.math; import omr.Main; import omr.util.BaseTestCase; import static junit.framework.Assert.*; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; //import org.testng.annotations.*; import java.io.FileOutputStream; import javax.xml.bind.*; /** * Class * <code>NeuralNetworkTest</code> performs unit tests on * NeuralNetwork class. * * @author Hervé Bitteur * @version $Id$ */ public class NeuralNetworkTest extends BaseTestCase { //~ Static fields/initializers --------------------------------------------- private static final double maxMSE = 0.3; private static NeuralNetwork nn; //~ Methods ---------------------------------------------------------------- //-------------------// // testBackupRestore // //-------------------// //@Test public void testBackupRestore () { double[][] inputs = new double[][]{ {0, 0}, {1, 0}, {0, 1}, {1, 1} }; double[][] desiredOutputs = new double[][]{ {0}, {1}, {1}, {0} }; NeuralNetwork.Monitor monitor = new MyMonitor(); do { nn = createNetwork(); } while (nn.train(inputs, desiredOutputs, monitor) > maxMSE); nn.dump(); NeuralNetwork.Backup backup = nn.backup(); // Check behavior on incompatible network NeuralNetwork pp = createNetwork(1, 2, 3); try { pp.restore(backup); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException expected) { checkException(expected); } pp = createNetwork(); // Check behavior on a null backup try { pp.restore(null); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException expected) { checkException(expected); } // Check normal backup pp.restore(backup); assertNears( "0 xor 0 should be 0", 0d, pp.run(new double[]{0, 0}, null, null)[0], 0.1d); assertNears( "1 xor 0 should be 1", 1d, pp.run(new double[]{1, 0}, null, null)[0], 0.1d); assertNears( "0 xor 1 should be 1", 1d, pp.run(new double[]{0, 1}, null, null)[0], 0.1d); assertNears( "1 xor 1 should be 0", 0d, pp.run(new double[]{1, 1}, null, null)[0], 0.1d); } //-----------------// // testMarshalling // //-----------------// //@Test public void testMarshalling () throws JAXBException, FileNotFoundException { double[][] inputs = new double[][]{ {0, 0}, {1, 0}, {0, 1}, {1, 1} }; double[][] desiredOutputs = new double[][]{ {0}, {1}, {1}, {0} }; NeuralNetwork.Monitor monitor = new MyMonitor(); do { nn = createNetwork(); } while (nn.train(inputs, desiredOutputs, monitor) > maxMSE); nn.dump(); File dir = new File("data/temp"); dir.mkdirs(); File file = new File(dir, "nn.xml"); // Marshalling System.out.println("Marshalling to " + file); nn.marshal(new FileOutputStream(file)); System.out.println("Marshalled"); // Unmarshalling System.out.println("Unmarshalling from " + file); nn = NeuralNetwork.unmarshal(new FileInputStream(file)); System.out.println("Unmarshalled"); Main.dumping.dump(nn); nn.dump(); } //--------// // testOr // //--------// //@Test public void testOr () { double[][] inputs = new double[][]{ {0, 0}, {1, 0}, {0, 1}, {1, 1} }; double[][] desiredOutputs = new double[][]{ {0}, {1}, {1}, {1} }; NeuralNetwork.Monitor monitor = new MyMonitor(); do { nn = createNetwork(); } while (nn.train(inputs, desiredOutputs, monitor) > maxMSE); nn.dump(); assertNears( "0 or 0 should be 0", 0d, nn.run(new double[]{0, 0}, null, null)[0], 0.1d); assertNears( "1 or 0 should be 1", 1d, nn.run(new double[]{1, 0}, null, null)[0], 0.1d); assertNears( "0 or 1 should be 1", 1d, nn.run(new double[]{0, 1}, null, null)[0], 0.1d); assertNears( "1 or 1 should be 1", 1d, nn.run(new double[]{1, 1}, null, null)[0], 0.1d); } //---------// // testXOr // //---------// //@Test public void testXOr () { double[][] inputs = new double[][]{ {0, 0}, {1, 0}, {0, 1}, {1, 1} }; double[][] desiredOutputs = new double[][]{ {0}, {1}, {1}, {0} }; NeuralNetwork.Monitor monitor = new MyMonitor(); do { nn = createNetwork(); } while (nn.train(inputs, desiredOutputs, monitor) > maxMSE); nn.dump(); assertNears( "0 xor 0 should be 0", 0d, nn.run(new double[]{0, 0}, null, null)[0], 0.1d); assertNears( "1 xor 0 should be 1", 1d, nn.run(new double[]{1, 0}, null, null)[0], 0.1d); assertNears( "0 xor 1 should be 1", 1d, nn.run(new double[]{0, 1}, null, null)[0], 0.1d); assertNears( "1 xor 1 should be 0", 0d, nn.run(new double[]{1, 1}, null, null)[0], 0.1d); } //---------------// // createNetwork // //---------------// private NeuralNetwork createNetwork (int inputSize, int hiddenSize, int outputSize) { double amplitude = 0.5; double learningRate = 0.25; double momentum = 0.25; double maxError = 0.02; int epochs = 500000; return new NeuralNetwork( inputSize, hiddenSize, outputSize, amplitude, null, null, learningRate, momentum, maxError, epochs); } private NeuralNetwork createNetwork () { return createNetwork(2, 2, 1); } //~ Inner Classes ---------------------------------------------------------- //-----------// // MyMonitor // //-----------// private static class MyMonitor implements NeuralNetwork.Monitor { //~ Methods ------------------------------------------------------------ public void epochEnded (int epochIndex, double mse) { if ((epochIndex % 10000) == 0) { System.out.println( "epochEnded." + " epochIndex=" + epochIndex + " mse=" + mse); // Test for convergence if (epochIndex > 100000) { nn.stop(); } } } public void trainingStarted (final int epochIndex, final double mse) { System.out.println( "trainingStarted." + " epochIndex=" + epochIndex + " mse=" + mse); } } }