package com.rahul.trainer;
import java.util.ArrayList;
import java.util.List;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.Hopfield;
import com.rahul.bitboard.BitBoard;
import com.rahul.numboard.NumBoard;
/**
 * Command-line driver that trains a Neuroph {@link Hopfield} network on recorded
 * board positions and then replays every training row through the network,
 * printing the presented pattern, the recalled pattern and the fraction of
 * positions on which they differ.
 */
public class HopfieldNetwork {

    /**
     * TRUE if training on bitboards, FALSE if training on numboards.
     */
    private static final boolean IS_BIT = false;

    /** File of recorded bitboard rows, read when {@link #IS_BIT} is true. */
    private static final String BITBOARD_DATA_FILE = "./data/bitboards_game4.bb";

    /** File of recorded numboard rows, read when {@link #IS_BIT} is false. */
    private static final String NUMBOARD_DATA_FILE = "./data/numboards_game4.bb";

    /**
     * Number of {@code calculate()} passes made after a pattern is presented,
     * giving the network repeated update steps before its output is read.
     */
    private static final int RECALL_ITERATIONS = 9;

    /**
     * Builds the training set, trains the network and prints recall results.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Load the training rows first, then build and train the network.
        DataSet trainingSet = IS_BIT ? buildBitBoardTrainingSet() : buildNumBoardTrainingSet();

        Hopfield myHopfield = new Hopfield(inputSize());
        myHopfield.learn(trainingSet);

        System.out.println("Testing network");
        // The network is evaluated on the very rows it was trained on.
        testNetwork(myHopfield, trainingSet);
    }

    /** Returns the number of network inputs (and neurons) for the selected board type. */
    private static int inputSize() {
        return IS_BIT ? bitBoardInputSize() : NumBoard.BOARD_SIZE;
    }

    /**
     * Returns the input width of one bitboard row: {@code BitBoard.BOARD_LENGTH * 12}.
     * NOTE(review): the factor 12 presumably means one bit plane per piece kind; confirm.
     */
    private static int bitBoardInputSize() {
        return BitBoard.BOARD_LENGTH * 12;
    }

    /** Reads the bitboard file and converts each row into a {@link DataSet} row. */
    private static DataSet buildBitBoardTrainingSet() {
        DataSet trainingSet = new DataSet(bitBoardInputSize(), 1);
        ArrayList<com.rahul.bitboard.RowData> data = PrepareData.getBitBoardData(BITBOARD_DATA_FILE);
        for (com.rahul.bitboard.RowData row : data) {
            trainingSet.addRow(
                    PrepareData.inputToDoubleArray(row.getInput()),
                    PrepareData.expectedOutputToDoubleArray(row.getExpectedOutput()));
        }
        return trainingSet;
    }

    /** Reads the numboard file, normalizes it and copies each row into a {@link DataSet}. */
    private static DataSet buildNumBoardTrainingSet() {
        DataSet trainingSet = new DataSet(NumBoard.BOARD_SIZE, 1);
        ArrayList<com.rahul.numboard.RowData> data = PrepareData.getNumBoardData(NUMBOARD_DATA_FILE);
        // Called for its effect on `data`; presumably rescales the rows in place (confirm).
        Normalizer.normalize(data);
        for (com.rahul.numboard.RowData row : data) {
            trainingSet.addRow(row.getInput(), row.getExpectedOutput());
        }
        return trainingSet;
    }

    /**
     * Presents every row of {@code testData} to {@code network}, runs
     * {@link #RECALL_ITERATIONS} update passes and prints input, output and
     * their pairwise error, followed by the total row count.
     */
    private static void testNetwork(Hopfield network, DataSet testData) {
        List<DataSetRow> testSet = testData.getRows();
        for (DataSetRow testSetRow : testSet) {
            double[] input = testSetRow.getInput();
            network.setInput(input);
            for (int i = 0; i < RECALL_ITERATIONS; i++) {
                network.calculate();
            }
            double[] networkOutput = network.getOutput();
            printArray(input);
            printArray(networkOutput);
            // NOTE(review): with normalized numboard inputs, exact equality against the
            // network output may rarely hold; confirm this metric is meaningful there.
            printPairwiseError(input, networkOutput);
        }
        System.out.println("Total rows : " + testSet.size());
    }

    /** Prints every element with no decimals and no separator, then a newline. */
    private static void printArray(double[] a) {
        for (double value : a) {
            System.out.printf("%.0f", value);
        }
        System.out.println();
    }

    /**
     * Prints the fraction of positions at which {@code a} and {@code b} differ exactly.
     * Prints {@code 0.0} for two empty arrays instead of {@code NaN}.
     *
     * @throws IllegalArgumentException if the arrays have different lengths
     */
    private static void printPairwiseError(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Cannot compare vectors of different lengths: " + a.length + " vs " + b.length);
        }
        if (a.length == 0) {
            System.out.println(0f);
            return;
        }
        int mismatches = 0;
        for (int i = 0; i < a.length; i++) {
            if (a[i] != b[i]) {
                mismatches++;
            }
        }
        System.out.println((float) mismatches / a.length);
    }
}