package bots.mctsbot.ai.opponentmodels.weka; import java.io.File; import java.io.IOException; import java.net.URL; import org.apache.log4j.Logger; import weka.core.Instance; import bots.mctsbot.common.elements.player.PlayerId; public class ARFFPlayer { private final static Logger logger = Logger.getLogger(ARFFPlayer.class); private final Object player; private ARFFFile preCheckBetFile; private ARFFFile postCheckBetFile; private ARFFFile preFoldCallRaiseFile; private ARFFFile postFoldCallRaiseFile; private ARFFFile showdownFile; private boolean modelCreated = false; private WekaRegressionModel model = null; private WekaOptions config = new WekaOptions(); private ActionTrackingVisitor actions = null; private long writeCounter = 0; public ARFFPlayer(Object player, WekaRegressionModel baseModel, WekaOptions config, ActionTrackingVisitor actions) { if (!config.useOnlineLearning()) throw new IllegalStateException("ARFFPlayer can only be used with online learning!"); this.player = player; this.config = config; this.model = baseModel; this.actions = actions; try { // Begin new code -- can be better URL codeSource = getClass().getProtectionDomain().getCodeSource().getLocation(); File f = new File(codeSource.toURI()); while (!f.isDirectory()) { f = f.getParentFile(); if (f == null) { throw new IOException("Cannot find location for ARFFFiles"); } } String path = "./data/mctsbot/"; // End new code // String path = (getClass().getProtectionDomain().getCodeSource() // .getLocation().getPath() + folder).replace("%20", " "); // End old code preCheckBetFile = new ARFFFile(path, player, "PreCheckBet.arff", ARFFPropositionalizer.getPreCheckBetInstance().toString(), config); postCheckBetFile = new ARFFFile(path, player, "PostCheckBet.arff", ARFFPropositionalizer.getPostCheckBetInstance().toString(), config); preFoldCallRaiseFile = new ARFFFile(path, player, "PreFoldCallRaise.arff", ARFFPropositionalizer.getPreFoldCallRaiseInstance().toString(), config); postFoldCallRaiseFile = new ARFFFile(path, player, "PostFoldCallRaise.arff", ARFFPropositionalizer.getPostFoldCallRaiseInstance().toString(), config); showdownFile = new ARFFFile(path, player, "Showdown.arff", ARFFPropositionalizer.getShowdownInstance().toString(), config); } catch (IOException io) { throw new RuntimeException(io); } catch (Exception e) { throw new RuntimeException("Unable to create set of instances"); } } public void close() throws IOException { if (model != null) { preCheckBetFile.close(); postCheckBetFile.close(); preFoldCallRaiseFile.close(); postFoldCallRaiseFile.close(); showdownFile.close(); } } private String msgModelNotReady(ARFFFile file) { return file.getName().substring(0, file.getName().indexOf(".")) + " is not ready to be learned " + "(learning examples: " + file.getNrExamples() + " < " + config.getMinimalLearnExamples() + " required)"; } public double getAccuracy() { return actions.getAccuracy((PlayerId) player); } public void learnNewModel() { // if (!(preCheckBetFile.isModelReady() && postCheckBetFile.isModelReady() && preFoldCallRaiseFile.isModelReady() // && postFoldCallRaiseFile.isModelReady() && showdownFile.isModelReady())) { // System.out.println("\n MODEL NOT READY \n"); // return; // } System.out.println(""); logger.info("Learning new opponentModel for player " + player); modelCreated = true; // learning preCheckBetModel if (preCheckBetFile.isModelReady()) learnPreCheckBet(); else logger.info(msgModelNotReady(preCheckBetFile)); // learning postCheckBetModel if (postCheckBetFile.isModelReady()) learnPostCheckBet(); else logger.info(msgModelNotReady(postCheckBetFile)); // learning preFoldCallRaiseModel if (preFoldCallRaiseFile.isModelReady()) learnPreFoldCallRaise(); else logger.info(msgModelNotReady(preFoldCallRaiseFile)); // learning postFoldCallRaiseModel if (postFoldCallRaiseFile.isModelReady()) learnPostFoldCallRaise(); else logger.info(msgModelNotReady(postFoldCallRaiseFile)); // learning showdownModel if (showdownFile.isModelReady()) learnShowdown(); else logger.info(msgModelNotReady(showdownFile)); System.out.println(""); } public boolean writeAllowed() { return !modelCreated || (modelCreated && (config.continuousLearning() || config.continueAfterCreation())); } public boolean learningAllowed() { return (modelCreated && config.continuousLearning()) || (!modelCreated && (writeCounter >= config.modelCreationTreshold())); } public boolean modelCreated() { return modelCreated; } public void addPreCheckBetPrediction(Prediction p) { preCheckBetFile.addPrediction(p); } public void writePreCheckBet(Instance instance) { if (writeAllowed()) { preCheckBetFile.write(instance); incrementWriteCounter(); if (learningAllowed()) { if (modelCreated) learnPreCheckBet(); else learnNewModel(); } } } public void learnPreCheckBet() { try { logger.trace("Learning preBetModel for player " + player); model.setPreBetModel(preCheckBetFile.createModel("preBet", "betProb", new String[] { "action" })); } catch (Exception e) { e.printStackTrace(); } } public void addPostCheckBetPrediction(Prediction p) { postCheckBetFile.addPrediction(p); } public void writePostCheckBet(Instance instance) { if (writeAllowed()) { postCheckBetFile.write(instance); incrementWriteCounter(); if (learningAllowed()) { if (modelCreated) learnPostCheckBet(); else learnNewModel(); } } } public void learnPostCheckBet() { try { logger.trace("Learning postBetModel for player " + player); model.setPostBetModel(postCheckBetFile.createModel("postBet", "betProb", new String[] { "action" })); } catch (Exception e) { e.printStackTrace(); } } public void addPreFoldCallRaisePrediction(Prediction p) { preFoldCallRaiseFile.addPrediction(p); } public void writePreFoldCallRaise(Instance instance) { if (writeAllowed()) { preFoldCallRaiseFile.write(instance); incrementWriteCounter(); if (learningAllowed()) { if (modelCreated) learnPreFoldCallRaise(); else learnNewModel(); } } } public void learnPreFoldCallRaise() { try { logger.trace("Learning preFoldModel for player " + player); model.setPreFoldModel(preFoldCallRaiseFile.createModel("preFold", "foldProb", new String[] { "callProb", "raiseProb", "action" })); logger.trace("Learning preCallModel for player " + player); model.setPreCallModel(preFoldCallRaiseFile.createModel("preCall", "callProb", new String[] { "foldProb", "raiseProb", "action" })); logger.trace("Learning preRaiseModel for player " + player); model.setPreRaiseModel(preFoldCallRaiseFile.createModel("preRaise", "raiseProb", new String[] { "callProb", "foldProb", "action" })); } catch (Exception e) { e.printStackTrace(); } } public void addPostFoldCallRaisePrediction(Prediction p) { postFoldCallRaiseFile.addPrediction(p); } public void writePostFoldCallRaise(Instance instance) { if (writeAllowed()) { postFoldCallRaiseFile.write(instance); incrementWriteCounter(); if (learningAllowed()) { if (modelCreated) learnPostFoldCallRaise(); else learnNewModel(); } } } public void learnPostFoldCallRaise() { try { logger.trace("Learning postFoldModel for player " + player); model.setPostFoldModel(postFoldCallRaiseFile.createModel("postFold", "foldProb", new String[] { "callProb", "raiseProb", "action" })); logger.trace("Learning postCallModel for player " + player); model.setPostCallModel(postFoldCallRaiseFile.createModel("postCall", "callProb", new String[] { "foldProb", "raiseProb", "action" })); logger.trace("Learning postRaiseModel for player " + player); model.setPostRaiseModel(postFoldCallRaiseFile.createModel("postRaise", "raiseProb", new String[] { "callProb", "foldProb", "action" })); } catch (Exception e) { e.printStackTrace(); } } public void addShowdownPrediction(Prediction p) { showdownFile.addPrediction(p); } public void writeShowdown(Instance instance) { if (writeAllowed()) { showdownFile.write(instance); incrementWriteCounter(); if (learningAllowed()) { if (modelCreated) learnShowdown(); else learnNewModel(); } } } public void learnShowdown() { try { logger.trace("Learning showdown0Model for player " + player); model.setShowdown0Model(showdownFile.createModel("showdown0", "part0Prob", new String[] { "part1Prob", "part2Prob", "part3Prob", "part4Prob", "part5Prob", "avgPartition" })); logger.trace("Learning showdown1Model for player " + player); model.setShowdown1Model(showdownFile.createModel("showdown1", "part1Prob", new String[] { "part0Prob", "part2Prob", "part3Prob", "part4Prob", "part5Prob", "avgPartition" })); logger.trace("Learning showdown2Model for player " + player); model.setShowdown2Model(showdownFile.createModel("showdown5", "part2Prob", new String[] { "part0Prob", "part1Prob", "part3Prob", "part4Prob", "part5Prob", "avgPartition" })); logger.trace("Learning showdown3Model for player " + player); model.setShowdown3Model(showdownFile.createModel("showdown3", "part3Prob", new String[] { "part0Prob", "part1Prob", "part2Prob", "part4Prob", "part5Prob", "avgPartition" })); logger.trace("Learning showdown4Model for player " + player); model.setShowdown4Model(showdownFile.createModel("showdown4", "part4Prob", new String[] { "part0Prob", "part1Prob", "part2Prob", "part3Prob", "part5Prob", "avgPartition" })); logger.trace("Learning showdown5Model for player " + player); model.setShowdown5Model(showdownFile.createModel("showdown5", "part5Prob", new String[] { "part0Prob", "part1Prob", "part2Prob", "part3Prob", "part4Prob", "avgPartition" })); } catch (Exception e) { e.printStackTrace(); } } private void incrementWriteCounter() { writeCounter++; String str = ""; str += preCheckBetFile.getAccuracy() + "\t" + preCheckBetFile.getWindowSize() + "\t"; str += postCheckBetFile.getAccuracy() + "\t" + postCheckBetFile.getWindowSize() + "\t"; str += preFoldCallRaiseFile.getAccuracy() + "\t" + preFoldCallRaiseFile.getWindowSize() + "\t"; str += postFoldCallRaiseFile.getAccuracy() + "\t" + postFoldCallRaiseFile.getWindowSize() + "\t"; // str += showdownFile.getAccuracy() + "\t" + showdownFile.getWindowSize() System.out.println(str); // System.out.println("=" + writeCounter + "="); } }