package org.encog.examples.neural.xordisplay;
import java.util.Arrays;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.TrainFlatNetworkResilient;
import org.encog.mathutil.randomize.ConsistentRandomizer;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.logging.Logging;
import org.encog.util.simple.EncogUtility;
public class XORDisplay {
public final static int ITERATIONS = 10;
public static double XOR_INPUT[][] = { { 0.0, 0.0 }, { 1.0, 0.0 },
{ 0.0, 1.0 }, { 1.0, 1.0 } };
public static double XOR_IDEAL[][] = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };
public static void displayWeights(FlatNetwork network) {
System.out.println("Weights:"
+ Arrays.toString(network.getWeights()));
}
public static void evaluate(FlatNetwork network, NeuralDataSet trainingSet )
{
double[] output = new double[1];
for(NeuralDataPair pair: trainingSet ) {
network.compute(pair.getInput().getData(), output);
System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
+ ", actual=" + output[0] + ",ideal=" + pair.getIdeal().getData(0));
}
}
public static FlatNetwork createNetwork() {
BasicNetwork network = EncogUtility
.simpleFeedForward(2, 4, 0, 1, false);
Randomizer randomizer = new ConsistentRandomizer(-1, 1);
randomizer.randomize(network);
return network.getStructure().getFlat().clone();
}
public static void main(String[] args) {
Logging.stopConsoleLogging();
NeuralDataSet trainingSet = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
FlatNetwork network = createNetwork();
System.out.println("Starting Weights:");
displayWeights(network);
evaluate(network,trainingSet);
final TrainFlatNetworkResilient train = new TrainFlatNetworkResilient(
network, trainingSet);
for (int iteration = 1; iteration <= ITERATIONS; iteration++) {
train.iteration();
System.out.println();
System.out.println("*** Iteration #" + iteration);
System.out.println("Error: " + train.getError());
evaluate(network,trainingSet);
System.out.println("LastGrad:"
+ Arrays.toString(train.getLastGradient()));
System.out.println("Updates :"
+ Arrays.toString(train.getUpdateValues()));
displayWeights(network);
}
}
}