package com.rahul.trainer; import java.util.ArrayList; import java.util.Arrays; import org.neuroph.core.NeuralNetwork; import org.neuroph.core.data.DataSet; import org.neuroph.core.data.DataSetRow; import org.neuroph.core.events.LearningEvent; import org.neuroph.core.events.LearningEventListener; import org.neuroph.core.learning.LearningRule; import org.neuroph.nnet.MultiLayerPerceptron; import org.neuroph.nnet.learning.BackPropagation; import org.neuroph.nnet.learning.MomentumBackpropagation; import org.neuroph.util.TransferFunctionType; import com.rahul.bitboard.BitBoard; import com.rahul.numboard.NumBoard; public class MLP implements LearningEventListener { public static void main(String[] args) { new MLP().run(false); } /** * isBit should be set to true for training the network on bitboard data and * false for training from numboard data * * @param isBit */ public void run(boolean isBit) { DataSet trainingSet; // create training set if (isBit) { ArrayList<com.rahul.bitboard.RowData> data; trainingSet = new DataSet(BitBoard.BOARD_LENGTH * 12, 1); data = PrepareData.getBitBoardData("./data/bitboards_game4.bb"); for (com.rahul.bitboard.RowData row : data) { trainingSet.addRow(PrepareData.inputToDoubleArray(row .getInput()), PrepareData .expectedOutputToDoubleArray(row.getExpectedOutput())); } } else { ArrayList<com.rahul.numboard.RowData> data; trainingSet = new DataSet(NumBoard.BOARD_SIZE, 1); data = PrepareData.getNumBoardData("./data/numboards_game4.bb"); for (com.rahul.numboard.RowData row : data) { trainingSet.addRow(row.getInput(), row.getExpectedOutput()); } } // trainingSet.normalize(); // create multi layer perceptron MultiLayerPerceptron myMlPerceptron; if (isBit) { myMlPerceptron = new MultiLayerPerceptron( TransferFunctionType.TANH, BitBoard.BOARD_LENGTH * 12, BitBoard.BOARD_LENGTH * 24, BitBoard.BOARD_LENGTH * 12, 1); } else { myMlPerceptron = new MultiLayerPerceptron( TransferFunctionType.TANH, NumBoard.BOARD_SIZE, NumBoard.BOARD_SIZE * 4, NumBoard.BOARD_SIZE * 2, NumBoard.BOARD_SIZE, 1); } // enable batch if using MomentumBackpropagation if (myMlPerceptron.getLearningRule() instanceof MomentumBackpropagation) ((MomentumBackpropagation) myMlPerceptron.getLearningRule()) .setBatchMode(true); LearningRule learningRule = myMlPerceptron.getLearningRule(); learningRule.addListener(this); // learn the training set System.out.println("Training neural network..."); myMlPerceptron.learn(trainingSet); // test perceptron System.out.println("Testing trained neural network"); testNeuralNetwork(myMlPerceptron, trainingSet); // save trained neural network myMlPerceptron.save("2Layer_try01.nnet"); // load saved neural network NeuralNetwork loadedMlPerceptron = NeuralNetwork .load("2Layer_try01.nnet"); // test loaded neural network System.out.println("Testing loaded neural network"); testNeuralNetwork(loadedMlPerceptron, trainingSet); } public static void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) { for (DataSetRow testSetRow : testSet.getRows()) { neuralNet.setInput(testSetRow.getInput()); neuralNet.calculate(); double[] networkOutput = neuralNet.getOutput(); System.out .print("Input: " + Arrays.toString(testSetRow.getInput())); System.out.println(" Output: " + Arrays.toString(networkOutput)); } } @Override public void handleLearningEvent(LearningEvent event) { BackPropagation bp = (BackPropagation) event.getSource(); System.out.println(bp.getCurrentIteration() + ". iteration : " + bp.getTotalNetworkError()); } }