package opt.test;
import opt.*;
import opt.example.*;
import opt.ga.*;
import shared.*;
import func.nn.backprop.*;
import java.util.*;
import java.io.*;
import java.text.*;
/**
* Implementation of randomized hill climbing, simulated annealing, and a genetic algorithm to
* find optimal weights for a neural network that classifies abalone as having either fewer
* than 15 rings or 15 or more rings.
*
* @author Hannah Lau
* @version 1.0
*/
public class AbaloneTest {
private static Instance[] instances = initializeInstances();
private static int inputLayer = 7, hiddenLayer = 5, outputLayer = 1, trainingIterations = 1000;
private static BackPropagationNetworkFactory factory = new BackPropagationNetworkFactory();
private static ErrorMeasure measure = new SumOfSquaresError();
private static DataSet set = new DataSet(instances);
private static BackPropagationNetwork networks[] = new BackPropagationNetwork[3];
private static NeuralNetworkOptimizationProblem[] nnop = new NeuralNetworkOptimizationProblem[3];
private static OptimizationAlgorithm[] oa = new OptimizationAlgorithm[3];
private static String[] oaNames = {"RHC", "SA", "GA"};
private static String results = "";
private static DecimalFormat df = new DecimalFormat("0.000");
public static void main(String[] args) {
for(int i = 0; i < oa.length; i++) {
networks[i] = factory.createClassificationNetwork(
new int[] {inputLayer, hiddenLayer, outputLayer});
nnop[i] = new NeuralNetworkOptimizationProblem(set, networks[i], measure);
}
oa[0] = new RandomizedHillClimbing(nnop[0]);
oa[1] = new SimulatedAnnealing(1E11, .95, nnop[1]);
oa[2] = new StandardGeneticAlgorithm(200, 100, 10, nnop[2]);
for(int i = 0; i < oa.length; i++) {
double start = System.nanoTime(), end, trainingTime, testingTime, correct = 0, incorrect = 0;
train(oa[i], networks[i], oaNames[i]); //trainer.train();
end = System.nanoTime();
trainingTime = end - start;
trainingTime /= Math.pow(10,9);
Instance optimalInstance = oa[i].getOptimal();
networks[i].setWeights(optimalInstance.getData());
double predicted, actual;
start = System.nanoTime();
for(int j = 0; j < instances.length; j++) {
networks[i].setInputValues(instances[j].getData());
networks[i].run();
predicted = Double.parseDouble(instances[j].getLabel().toString());
actual = Double.parseDouble(networks[i].getOutputValues().toString());
double trash = Math.abs(predicted - actual) < 0.5 ? correct++ : incorrect++;
}
end = System.nanoTime();
testingTime = end - start;
testingTime /= Math.pow(10,9);
results += "\nResults for " + oaNames[i] + ": \nCorrectly classified " + correct + " instances." +
"\nIncorrectly classified " + incorrect + " instances.\nPercent correctly classified: "
+ df.format(correct/(correct+incorrect)*100) + "%\nTraining time: " + df.format(trainingTime)
+ " seconds\nTesting time: " + df.format(testingTime) + " seconds\n";
}
System.out.println(results);
}
private static void train(OptimizationAlgorithm oa, BackPropagationNetwork network, String oaName) {
System.out.println("\nError results for " + oaName + "\n---------------------------");
for(int i = 0; i < trainingIterations; i++) {
oa.train();
double error = 0;
for(int j = 0; j < instances.length; j++) {
network.setInputValues(instances[j].getData());
network.run();
Instance output = instances[j].getLabel(), example = new Instance(network.getOutputValues());
example.setLabel(new Instance(Double.parseDouble(network.getOutputValues().toString())));
error += measure.value(output, example);
}
System.out.println(df.format(error));
}
}
private static Instance[] initializeInstances() {
double[][][] attributes = new double[4177][][];
try {
BufferedReader br = new BufferedReader(new FileReader(new File("src/opt/test/abalone.txt")));
for(int i = 0; i < attributes.length; i++) {
Scanner scan = new Scanner(br.readLine());
scan.useDelimiter(",");
attributes[i] = new double[2][];
attributes[i][0] = new double[7]; // 7 attributes
attributes[i][1] = new double[1];
for(int j = 0; j < 7; j++)
attributes[i][0][j] = Double.parseDouble(scan.next());
attributes[i][1][0] = Double.parseDouble(scan.next());
}
br.close();
}
catch(Exception e) {
e.printStackTrace();
}
Instance[] instances = new Instance[attributes.length];
for(int i = 0; i < instances.length; i++) {
instances[i] = new Instance(attributes[i][0]);
// classifications range from 0 to 30; split into 0 - 14 and 15 - 30
instances[i].setLabel(new Instance(attributes[i][1][0] < 15 ? 0 : 1));
}
return instances;
}
}