package opt.test;
import opt.OptimizationAlgorithm;
import opt.RandomizedHillClimbing;
import opt.example.NeuralNetworkOptimizationProblem;
import shared.DataSet;
import shared.ErrorMeasure;
import shared.FixedIterationTrainer;
import shared.Instance;
import shared.SumOfSquaresError;
import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.BackPropagationNetworkFactory;
/**
 * Demonstrates searching the weights of a small neural network with
 * randomized hill climbing on an XOR-style classification task.
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class XORTest {
    /** Number of randomized-hill-climbing iterations the trainer runs. */
    private static final int TRAINING_ITERATIONS = 5000;

    /**
     * Builds a four-row XOR-style data set and creates a back-propagation
     * classification network with layer sizes {4, 3, 1}. It then searches the
     * network's weights with randomized hill climbing, scored by sum-of-squares
     * error. Finally, for each pattern it prints a "~~" separator, the pattern's
     * label, and the network's output.
     * @param args ignored
     */
    public static void main(String[] args) {
        BackPropagationNetworkFactory networkFactory = new BackPropagationNetworkFactory();

        // Each row pairs an input vector with its expected output. The output is
        // the XOR of the first two inputs; inputs 3 and 4 repeat inputs 1 and 2.
        double[][][] truthTable = {
            { { 1, 1, 1, 1 }, { 0 } },
            { { 1, 0, 1, 0 }, { 1 } },
            { { 0, 1, 0, 1 }, { 1 } },
            { { 0, 0, 0, 0 }, { 0 } }
        };
        Instance[] examples = toLabeledInstances(truthTable);

        // Layer widths: 4 matches the input vectors and 1 matches the labels;
        // the middle 3 is presumably a single hidden layer (confirm in factory).
        BackPropagationNetwork network =
            networkFactory.createClassificationNetwork(new int[] { 4, 3, 1 });

        ErrorMeasure sumSquaredError = new SumOfSquaresError();
        DataSet trainingSet = new DataSet(examples);
        NeuralNetworkOptimizationProblem problem =
            new NeuralNetworkOptimizationProblem(trainingSet, network, sumSquaredError);
        OptimizationAlgorithm hillClimber = new RandomizedHillClimbing(problem);
        new FixedIterationTrainer(hillClimber, TRAINING_ITERATIONS).train();

        // Install the optimizer's reported optimum as the network's weights.
        Instance bestWeights = hillClimber.getOptimal();
        network.setWeights(bestWeights.getData());

        printPredictions(network, examples);
    }

    /**
     * Wraps each { inputs, label } row as an Instance over the inputs whose
     * label is an Instance over the label values.
     * @param rows the rows to convert; element [r][0] is the input, [r][1] the label
     * @return one labeled instance per row, in row order
     */
    private static Instance[] toLabeledInstances(double[][][] rows) {
        Instance[] instances = new Instance[rows.length];
        for (int row = 0; row < rows.length; row++) {
            Instance instance = new Instance(rows[row][0]);
            instance.setLabel(new Instance(rows[row][1]));
            instances[row] = instance;
        }
        return instances;
    }

    /**
     * Feeds each example's data through the network and prints, per example,
     * a "~~" separator line, the example's label, and the network's output.
     * @param network the network to evaluate (its weights must already be set)
     * @param examples the labeled examples to evaluate, printed in order
     */
    private static void printPredictions(BackPropagationNetwork network, Instance[] examples) {
        for (Instance example : examples) {
            network.setInputValues(example.getData());
            network.run();
            System.out.println("~~");
            System.out.println(example.getLabel());
            System.out.println(network.getOutputValues());
        }
    }
}