/**
* Copyright (C) 2017 Jan Schäfer (jansch@users.sourceforge.net)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jskat.ai.nn.util;
import static org.junit.Assert.fail;
import java.util.Arrays;
import org.encog.Encog;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.propagation.resilient.RPROPType;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.util.simple.EncogUtility;
import org.jskat.AbstractJSkatTest;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Tests for training neural networks with the Encog library, both through the
 * JSkat {@link EncogNetworkWrapper} and through Encog's {@link BasicNetwork}
 * directly, using XOR as the training problem.
 */
public class EncogNetworkWrapperTest extends AbstractJSkatTest {

    /**
     * Error threshold: training counts as successful once the reported error is
     * at or below this value.
     */
    private static final double MIN_DIFF = 0.01;

    /**
     * Maximum number of single-sample training steps in the hand-written
     * training loops of {@link #testXOR()} and {@link #testXORDirect()}.
     */
    private static final int MAX_ITERATIONS = 500;

    /**
     * Maximum number of training epochs for the tests driven by an Encog
     * trainer. Bounding training makes a network that does not converge fail
     * the test instead of hanging it.
     */
    private static final int MAX_EPOCHS = 5000;

    /**
     * A prediction must be strictly closer than this to its XOR target, i.e. on
     * the correct side of the 0.5 decision boundary.
     */
    private static final double PREDICTION_TOLERANCE = 0.5;

    /**
     * Logger.
     */
    private static final Logger log = LoggerFactory.getLogger(EncogNetworkWrapperTest.class);

    /**
     * Tests the NetworkWrapper with an XOR example, training on one sample per
     * step and cycling through the samples.
     */
    @Test
    @Ignore
    public final void testXOR() {
        int[] hiddenNeurons = { 3 };
        NetworkTopology topo = new NetworkTopology(2, hiddenNeurons, 1);
        INeuralNetwork network = new EncogNetworkWrapper(topo, false);
        network.resetNetwork();

        double[][] input = { { 1.0, 1.0 }, { 1.0, 0.0 }, { 0.0, 1.0 }, { 0.0, 0.0 } };
        // A XOR B
        double[][] output = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };

        // NOTE(review): the loop stops as soon as the value returned by a single
        // adjustWeights call (presumably that sample's error) drops below
        // MIN_DIFF. That does not guarantee the other samples are learned, so
        // the predictions are only logged, not asserted.
        double error = Double.MAX_VALUE;
        int iteration = 0;
        while (error > MIN_DIFF && iteration < MAX_ITERATIONS) {
            int sample = iteration % input.length;
            error = network.adjustWeights(input[sample], output[sample]);
            iteration++;
        }

        // Judge by the error, not the counter: converging on the very last
        // allowed iteration is still a success.
        if (error > MIN_DIFF) {
            fail("Needed more than " + MAX_ITERATIONS + " iterations. Error: " + error);
        }

        log.info("Needed {} iterations to learn.", iteration);
        log.info("Testing network:");
        for (int n = 0; n < input.length; n++) {
            log.info("Input: " + input[n][0] + " " + input[n][1] + " Expected output: " + output[n][0]
                    + " Predicted output: " + network.getPredictedOutcome(input[n]));
        }
    }

    /**
     * Tests the {@link BasicNetwork} directly with an XOR example, running one
     * backpropagation step on a single sample at a time.
     */
    @Test
    @Ignore
    public final void testXORDirect() {
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        double[][] input = { { 1.0, 1.0 }, { 1.0, 0.0 }, { 0.0, 1.0 }, { 0.0, 0.0 } };
        // A XOR B
        double[][] output = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };

        // Build the full training set plus one single-sample data set per pair,
        // once, instead of rebuilding a data set on every training step.
        BasicMLDataSet trainingSet = new BasicMLDataSet();
        MLDataSet[] sampleSets = new MLDataSet[input.length];
        for (int n = 0; n < input.length; n++) {
            trainingSet.add(new BasicMLDataPair(new BasicMLData(input[n]), new BasicMLData(output[n])));
            sampleSets[n] = new BasicMLDataSet(Arrays.asList(trainingSet.get(n)));
        }

        // A fresh trainer is created for every step, so no trainer state carries
        // over between steps. Samples are visited in order starting at index 0.
        double error = Double.MAX_VALUE;
        int iteration = 0;
        while (error > MIN_DIFF && iteration < MAX_ITERATIONS) {
            Backpropagation trainer = new Backpropagation(network, sampleSets[iteration % sampleSets.length]);
            trainer.setBatchSize(1);
            trainer.iteration();
            error = trainer.getError();
            iteration++;
        }

        // Judge by the error, not the counter: converging on the very last
        // allowed iteration is still a success.
        if (error > MIN_DIFF) {
            fail("Needed more than " + MAX_ITERATIONS + " iterations. Error: " + error);
        }

        log.debug("Needed {} iterations to learn.", iteration);
        log.debug("Testing network:");
        for (int n = 0; n < input.length; n++) {
            log.debug("Input: " + input[n][0] + " " + input[n][1] + " Expected output: " + output[n][0]
                    + " Predicted output: " + network.compute(new BasicMLData(input[n])));
        }
    }

    /**
     * Tests online backpropagation training (weights updated after every
     * sample) of a network with two hidden layers on XOR.
     */
    @Test
    @Ignore
    public void testXOROnlineTraining() {
        double[][] xorInput = { { 0.0, 0.0 }, { 1.0, 0.0 }, { 0.0, 1.0 }, { 1.0, 1.0 } };
        double[][] xorIdeal = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };
        try {
            // 2 inputs, hidden layers of 3 and 2 neurons, 1 output; the last
            // argument false selects sigmoid instead of tanh activation.
            BasicNetwork network = EncogUtility.simpleFeedForward(2, 3, 2, 1, false);
            network.reset();
            MLDataSet trainingSet = new BasicMLDataSet(xorInput, xorIdeal);

            // Learning rate 0.07, momentum 0.02; batch size 1 means online training.
            final Backpropagation train = new Backpropagation(network, trainingSet, 0.07, 0.02);
            train.setBatchSize(1);

            // Bounded replacement for EncogUtility.trainToError, which has no
            // epoch limit and would hang on a network that never converges.
            double error;
            int epoch = 0;
            do {
                train.iteration();
                error = train.getError();
                epoch++;
            } while (error > MIN_DIFF && epoch < MAX_EPOCHS);
            train.finishTraining();

            if (error > MIN_DIFF) {
                fail("Needed more than " + MAX_EPOCHS + " epochs. Error: " + error);
            }
            log.debug("Needed {} epochs to learn.", epoch);
            assertLearnedXOR(network, xorInput, xorIdeal);
        } finally {
            // Always release Encog's resources, even when training fails.
            Encog.getInstance().shutdown();
        }
    }

    /**
     * Tests resilient propagation (iRPROP+) training of a network with one
     * hidden layer on XOR.
     */
    @Test
    public void testXORResilientTraining() {
        double[][] xorInput = { { 0.0, 0.0 }, { 1.0, 0.0 }, { 0.0, 1.0 }, { 1.0, 1.0 } };
        double[][] xorIdeal = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };
        try {
            // 2 inputs, one hidden layer of 3 neurons (0 = no second hidden
            // layer), 1 output; false selects sigmoid instead of tanh activation.
            BasicNetwork network = EncogUtility.simpleFeedForward(2, 3, 0, 1, false);
            network.reset();
            MLDataSet trainingSet = new BasicMLDataSet(xorInput, xorIdeal);

            final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
            train.setRPROPType(RPROPType.iRPROPp);

            // Bounded replacement for EncogUtility.trainToError, which has no
            // epoch limit and would hang on a network stuck in a local minimum.
            double error;
            int epoch = 0;
            do {
                train.iteration();
                error = train.getError();
                epoch++;
            } while (error > MIN_DIFF && epoch < MAX_EPOCHS);
            train.finishTraining();

            if (error > MIN_DIFF) {
                fail("Needed more than " + MAX_EPOCHS + " epochs. Error: " + error);
            }
            log.debug("Needed {} epochs to learn.", epoch);
            assertLearnedXOR(network, xorInput, xorIdeal);
        } finally {
            // Always release Encog's resources, even when training fails.
            Encog.getInstance().shutdown();
        }
    }

    /**
     * Fails the test unless {@code network} predicts every sample strictly
     * within {@link #PREDICTION_TOLERANCE} of its ideal value.
     *
     * @param network
     *            trained network; only its first output value is checked
     * @param input
     *            input vectors
     * @param ideal
     *            expected outputs, index-aligned with {@code input}
     */
    private static void assertLearnedXOR(final BasicNetwork network, final double[][] input,
            final double[][] ideal) {
        for (int n = 0; n < input.length; n++) {
            double predicted = network.compute(new BasicMLData(input[n])).getData(0);
            log.debug("Input: " + input[n][0] + " " + input[n][1] + " Expected output: " + ideal[n][0]
                    + " Predicted output: " + predicted);
            if (Math.abs(predicted - ideal[n][0]) >= PREDICTION_TOLERANCE) {
                fail("Wrong prediction for input " + Arrays.toString(input[n]) + ": expected " + ideal[n][0]
                        + " but was " + predicted);
            }
        }
    }
}