package bots.mctsbot.ai.bots.bot.gametree.tls.tests; import bots.mctsbot.ai.bots.bot.gametree.action.BetAction; import bots.mctsbot.ai.bots.bot.gametree.action.CallAction; import bots.mctsbot.ai.bots.bot.gametree.action.CheckAction; import bots.mctsbot.ai.bots.bot.gametree.action.DoNothingAction; import bots.mctsbot.ai.bots.bot.gametree.action.FoldAction; import bots.mctsbot.ai.bots.bot.gametree.action.RaiseAction; import bots.mctsbot.ai.bots.bot.gametree.action.SearchBotAction; import bots.mctsbot.ai.bots.bot.gametree.tls.nodes.AbstractTLSNode; import bots.mctsbot.ai.bots.util.RunningStats; public class Test { private static final Class[] line = { DoNothingAction.class, FoldAction.class, CheckAction.class, CallAction.class, BetAction.class, RaiseAction.class }; private final SearchBotAction testAction; private final int testIndex; private final RunningStats failStats = new RunningStats(); private final RunningStats successStats = new RunningStats(); private final AbstractTLSNode node; public Test(SearchBotAction action, AbstractTLSNode node) { this.testAction = action; testIndex = findIndex(action); this.node = node; } public SearchBotAction getTestAction() { return testAction; } public void updateStats(SearchBotAction action, double value) { if (succeeds(action)) successStats.add(value); else failStats.add(value); } public boolean succeeds(SearchBotAction action) { if (findIndex(action) < testIndex) return false; if (testAction instanceof BetAction && action instanceof BetAction) return ((BetAction) action).amount >= ((BetAction) testAction).amount; if (testAction instanceof RaiseAction && action instanceof RaiseAction) return ((RaiseAction) action).amount >= ((RaiseAction) testAction).amount; return true; } private int findIndex(SearchBotAction action) { for (int i = 0; i < line.length; i++) { if (line[i] == action.getClass()) return i; } throw new IllegalArgumentException("The runtime class of the argument does not match any class in the action line"); } public double getSDR() { return node.getStdDev() - failStats.getNbSamples() / node.getNbSamples() * failStats.getStdDev() - successStats.getNbSamples() / node.getNbSamples() * successStats.getStdDev(); } public RunningStats getLeftStats() { return failStats; } public RunningStats getRightStats() { return successStats; } }