package org.deeplearning4j.examples.tictactoe; import org.datavec.api.util.ClassPathResource; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.File; import java.io.FileWriter; import java.util.ArrayList; import java.util.Date; import java.util.List; /** * This program generates basic data to be used in Training Program. * It performs following major steps * - generates all possible game states * - reward all game states generated in above step by finding winning state, assign it to value 1 and goes back upto first step through all steps and * calculates probability of each step in the game to make that move win game in the last state. * - Writes all states data along with probability of each state to win the game which was calculated in above step. * Note : * - Used <b>http://www.se16.info/hgb/tictactoe.htm</b> link to understand all possible number of moves in Tic-Tac-Toe game. * - Refer ReadMe.txt for detail explanation of each step. * <p> * <b>Developed by KIT Solutions Pvt. Ltd. (www.kitsol.com), 19-Jan-2017.</b> */ public class TicTacToeData { //All these private variables are not meant to be used from outside of the class. So, no getter/setter methods are provided. private List<INDArray> moveSequenceList = new ArrayList<>(); private List<INDArray> oddPlayerWiningList = new ArrayList<>(); private List<INDArray> evenPlayerWiningList = new ArrayList<>(); private List<INDArray> middleList = new ArrayList<>(); private List<INDArray> finalOutputArrayList = new ArrayList<>(); private List<Double> finalProbabilityValueList = new ArrayList<>(); private int previousMoveNumber = 0; /** * Main function that calls all major functions one-by-one to generate training data to be used in training program. */ public static void main(String[] args) throws Exception { String filePath = new ClassPathResource("TicTacToe").getFile().getAbsolutePath() + File.separator + "AllMoveWithReward.txt"; TicTacToeData data = new TicTacToeData(); System.out.println("Data Processing Started : " + (new Date()).toString()); data.generatePossibleGames(); System.out.println("All possible game state sequence generated, Finished At : " + (new Date()).toString()); data.rewardGameState(); System.out.println("Reward calculation finished : " + (new Date()).toString()); data.writeFinalData(filePath); System.out.println("File generation completed : " + (new Date()).toString()); } /** * Initiate generating all possible game states. Refer ReadMe.txt for detailed explanation. */ public void generatePossibleGames() { try { for (int index = 1; index <= 9; index++) { generateStateBasedOnMoveNumber(index); } } catch (Exception e) { System.out.println(e.toString()); } /*Here process odd and Draw using odd list*/ oddPlayerWiningList.addAll(moveSequenceList); } /** * This function allocates reward points to each state of the game based on the winning state. * For all elements in oddPlayerWiningList, evenPlayerWiningList and middleList (which contains intermediate entries before winning or draw). * Refer ReadMe.txt for detailed explanation. */ public void rewardGameState() { for (INDArray a : oddPlayerWiningList) { generateGameStatesAndRewardToIt(a, 0);//0 odd for position and 1 for even Position } for (INDArray a : evenPlayerWiningList) { generateGameStatesAndRewardToIt(a, 1); } for (INDArray element : middleList) { addToFinalOutputList(element, 0.5); } } /** * This function called by generatePossibleGames. It is the main function that generates all possible game states. * Refer ReadMe.txt for detailed explanation. */ private void generateStateBasedOnMoveNumber(int moveNumber) throws Exception { int newMoveNumber = previousMoveNumber + 1; if (newMoveNumber != moveNumber) { throw new Exception("Missing one or more moves between 1 to 9"); } else if (moveNumber > 9 || moveNumber < 1) { throw new Exception("Invalid move number"); } previousMoveNumber = newMoveNumber; List<INDArray> tempMoveSequenceList = new ArrayList<>(); tempMoveSequenceList.addAll(moveSequenceList); moveSequenceList.clear(); if (moveNumber == 1) { for (int i = 0; i < 9; i++) { INDArray temp2 = Nd4j.zeros(1, 9); temp2.putScalar(new int[]{0, i}, 1); moveSequenceList.add(temp2); } } else { boolean isOddMoveNumber = ((moveNumber % 2) != 0) ? true : false; int lengthOfTempMoveSequenceList = tempMoveSequenceList.size(); for (int i = 0; i < lengthOfTempMoveSequenceList; i++) { INDArray moveArraySequence = tempMoveSequenceList.get(i); for (int j = 0; j < 9; j++) { INDArray temp1 = Nd4j.zeros(1, 9); Nd4j.copy(moveArraySequence, temp1); if (moveArraySequence.getInt(j) == 0) { temp1.putScalar(new int[]{0, j}, moveNumber); if (moveNumber > 4) { if (checkWin(temp1, isOddMoveNumber)) { if (isOddMoveNumber == true) { oddPlayerWiningList.add(temp1); } else { evenPlayerWiningList.add(temp1); } } else { moveSequenceList.add(temp1); } } else { moveSequenceList.add(temp1); } } } } } } /** * Identify the game state win/Draw. */ private boolean checkWin(INDArray sequence, boolean isOdd) { double boardPosition1 = sequence.getDouble(0); double boardPosition2 = sequence.getDouble(1); double boardPosition3 = sequence.getDouble(2); double boardPosition4 = sequence.getDouble(3); double boardPosition5 = sequence.getDouble(4); double boardPosition6 = sequence.getDouble(5); double boardPosition7 = sequence.getDouble(6); double boardPosition8 = sequence.getDouble(7); double boardPosition9 = sequence.getDouble(8); boolean position1 = isOdd ? (sequence.getDouble(0) % 2.0 != 0) : (sequence.getDouble(0) % 2.0 == 0); boolean position2 = isOdd ? (sequence.getDouble(1) % 2.0 != 0) : (sequence.getDouble(1) % 2.0 == 0); boolean position3 = isOdd ? (sequence.getDouble(2) % 2.0 != 0) : (sequence.getDouble(2) % 2.0 == 0); boolean position4 = isOdd ? (sequence.getDouble(3) % 2.0 != 0) : (sequence.getDouble(3) % 2.0 == 0); boolean position5 = isOdd ? (sequence.getDouble(4) % 2.0 != 0) : (sequence.getDouble(4) % 2.0 == 0); boolean position6 = isOdd ? (sequence.getDouble(5) % 2.0 != 0) : (sequence.getDouble(5) % 2.0 == 0); boolean position7 = isOdd ? (sequence.getDouble(6) % 2.0 != 0) : (sequence.getDouble(6) % 2.0 == 0); boolean position8 = isOdd ? (sequence.getDouble(7) % 2.0 != 0) : (sequence.getDouble(7) % 2.0 == 0); boolean position9 = isOdd ? (sequence.getDouble(8) % 2.0 != 0) : (sequence.getDouble(8) % 2.0 == 0); if (((position1 && position2 && position3) && (boardPosition1 != 0 && boardPosition2 != 0 && boardPosition3 != 0)) || ((position4 && position5 && position6) && (boardPosition4 != 0 && boardPosition5 != 0 && boardPosition6 != 0)) || ((position7 && position8 && position9) && (boardPosition7 != 0 && boardPosition8 != 0 && boardPosition9 != 0)) || ((position1 && position4 && position7) && (boardPosition1 != 0 && boardPosition4 != 0 && boardPosition7 != 0)) || ((position2 && position5 && position8) && (boardPosition2 != 0 && boardPosition5 != 0 && boardPosition8 != 0)) || ((position3 && position6 && position9) && (boardPosition3 != 0 && boardPosition6 != 0 && boardPosition9 != 0)) || ((position1 && position5 && position9) && (boardPosition1 != 0 && boardPosition5 != 0 && boardPosition9 != 0)) || ((position3 && position5 && position7) && (boardPosition3 != 0 && boardPosition5 != 0 && boardPosition7 != 0))) { return true; } else { return false; } } /** * This function generate all intermediate (including winning) game state from the winning state available oddPlayerWiningList or evenPlayerWiningList * and pass it to calculateReward function to calculate probability of all states of winning game. * Refer ReadMe.txt for detailed explanation. */ private void generateGameStatesAndRewardToIt(INDArray output, int moveType) { INDArray maxArray = Nd4j.max(output); double maxNumber = maxArray.getDouble(0); List<INDArray> sequenceList = new ArrayList<>(); INDArray sequenceArray = Nd4j.zeros(1, 9); int move = 1; int positionOfDigit = 0; for (int i = 1; i <= maxNumber; i++) { INDArray newTempArray = Nd4j.zeros(1, 9); positionOfDigit = getPosition(output, i); if (i % 2 == moveType) { Nd4j.copy(sequenceArray, newTempArray); sequenceList.add(newTempArray); } else { Nd4j.copy(sequenceArray, newTempArray); middleList.add(newTempArray); } sequenceArray.putScalar(new int[]{0, positionOfDigit}, move); move = move * (-1); } move = move * (-1); INDArray newTempArray2 = Nd4j.zeros(1, 9); sequenceArray.putScalar(new int[]{0, positionOfDigit}, move); Nd4j.copy(sequenceArray, newTempArray2); sequenceList.add(newTempArray2); calculateReward(sequenceList); } /** * This function gives cell number of a particular move */ private int getPosition(INDArray array, double number) { for (int i = 0; i < array.length(); i++) { if (array.getDouble(i) == number) { return i; } } return 0; } /** * Function to calculate Temporal Difference. Refer ReadMe.txt for detailed explanation. */ private void calculateReward(List<INDArray> arrayList) { double probabilityValue = 0; for (int p = (arrayList.size() - 1); p >= 0; p--) { if (p == (arrayList.size() - 1)) { probabilityValue = 1.0; } else { probabilityValue = 0.5 + 0.1 * (probabilityValue - 0.5); } INDArray stateAsINDArray = arrayList.get(p); addToFinalOutputList(stateAsINDArray, probabilityValue); } } /** * This function adds game states to final list after calculating reward for each state of a winning game. */ private void addToFinalOutputList(INDArray inputLabelArray, double inputRewardValue) { int indexPosition = finalOutputArrayList.indexOf(inputLabelArray); if (indexPosition != -1) { double rewardValue = finalProbabilityValueList.get(indexPosition); double newUpdatedRewardValue = (rewardValue > inputRewardValue) ? rewardValue : inputRewardValue; finalProbabilityValueList.set(indexPosition, newUpdatedRewardValue); } else { finalOutputArrayList.add(inputLabelArray); finalProbabilityValueList.add(inputRewardValue); } } /** * This function writes all states of all games into file along with their probability values. */ public void writeFinalData(String saveFilePath) { try (FileWriter writer = new FileWriter(saveFilePath)) { List<String> finalStringListForFile = new ArrayList<>(); for (int index = 0; index < finalOutputArrayList.size(); index++) { INDArray arrayFromInputList = finalOutputArrayList.get(index); double rewardValue = finalProbabilityValueList.get(index); String tempString = arrayFromInputList.toString().replace('[', ' ').replace(']', ' ').replace(',', ':').replaceAll("\\s", ""); String tempString2 = tempString; tempString = tempString.replaceAll("-1", "2"); String output = tempString + " " + String.valueOf(rewardValue); int indexInList1 = finalStringListForFile.indexOf(output); if (indexInList1 == -1) { finalStringListForFile.add(output); } tempString2 = tempString2.replaceAll("1", "2").replaceAll("-2", "1"); String output2 = tempString2 + " " + String.valueOf(rewardValue); int indexInList2 = finalStringListForFile.indexOf(output2); if (indexInList2 == -1) { finalStringListForFile.add(output2); } } for (String s : finalStringListForFile) { writer.append(s); writer.append('\r'); writer.append('\n'); writer.flush(); } } catch (Exception i) { System.out.println(i.toString()); } } }