/** * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. */ package bots.mctsbot.ai.bots.bot.gametree.rollout; import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; import org.apache.log4j.Logger; import bots.mctsbot.ai.opponentmodels.OpponentModel; import bots.mctsbot.client.common.gamestate.GameState; import bots.mctsbot.client.common.playerstate.PlayerState; import bots.mctsbot.common.elements.player.PlayerId; import com.biotools.meerkat.Card; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multiset; import com.google.common.collect.TreeMultiset; import com.google.common.collect.ImmutableMap.Builder; public class BucketRollOut extends RollOutStrategy { private final static Logger logger = Logger.getLogger(BucketRollOut.class); private final OpponentModel model; private final Map<PlayerId, double[]> bucketProbs; private final static int nbBuckets = 6; private final static int nbSamplesPerBucket = 6; public BucketRollOut(GameState gameState, PlayerId botId, OpponentModel model) { super(gameState, botId); this.model = model; Builder<PlayerId, double[]> builder = new ImmutableMap.Builder<PlayerId, double[]>(); for (PlayerState opponentThatCanWin : activeOpponents) { PlayerId playerId = opponentThatCanWin.getPlayerId(); double[] bucketProbs = model.getShowdownProbabilities(gameState, playerId); builder.put(playerId, bucketProbs); } bucketProbs = builder.build(); } //TODO optimize public double doRollOut(int nbCommunitySamples) { boolean traceEnabled = logger.isTraceEnabled(); double totalEV = 0; model.assumeTemporarily(gameState); for (int i = 0; i < nbCommunitySamples; i++) { int communitySampleRank = fixedRank; Set<Integer> usedCommunityAndBotCards = new TreeSet<Integer>(usedFixedCommunityAndBotCards); Set<Integer> usedCommunityCards = new TreeSet<Integer>(); for (int card = 0; card < usedFixedCommunityCards.size(); card++) { usedCommunityCards.add(usedFixedCommunityCards.getCardIndex(card + 1)); } for (int j = 0; j < nbMissingCommunityCards; j++) { Integer communityCard = drawNewCard(usedCommunityAndBotCards); if (traceEnabled) { logger.trace("Evaluating sampled community card " + communityCard); } usedCommunityCards.add(communityCard); communitySampleRank = updateIntermediateRank(communitySampleRank, new Card(communityCard)); } if (traceEnabled) { logger.trace("Evaluating bot cards " + botCard1 + " " + botCard2); } int botRank = getFinalRank(communitySampleRank, botCard1, botCard2); // int minSampleRank = Integer.MAX_VALUE; // int maxSampleRank = Integer.MIN_VALUE; // int sum = 0; Multiset<Integer> ranks = new TreeMultiset<Integer>(); Multiset<Integer> deadRanks = new TreeMultiset<Integer>(); int n = 100; for (int j = 0; j < n; j++) { Set<Integer> handCards = new TreeSet<Integer>(usedCommunityCards); Integer sampleCard1 = drawNewCard(handCards); Integer sampleCard2 = drawNewCard(handCards); int sampleRank = getFinalRank(communitySampleRank, new Card(sampleCard1), new Card(sampleCard2)); ranks.add(sampleRank); if (botCard1.equals(sampleCard1) || botCard1.equals(sampleCard2) || botCard2.equals(sampleCard1) || botCard2.equals(sampleCard2)) { deadRanks.add(sampleRank); } // if(sampleRank<minSampleRank){ // minSampleRank = sampleRank; // } // if(sampleRank>maxSampleRank){ // maxSampleRank = sampleRank; // } // sum += sampleRank; } // double mean = ((double)sum)/n; // double var = calcVariance(ranks, mean); // int averageSampleRank = (int) Math.round(mean); // int sigmaSampleRank = (int) Math.round(Math.sqrt(var)); WinDistribution[] winProbs = calcWinDistributions(botRank, ranks, deadRanks); double[] deadCardWeights = calcDeadCardWeights(ranks, deadRanks); TreeMap<PlayerState, WinDistribution> winDistributions = calcOpponentWinDistributionMap(winProbs, deadCardWeights); int maxDistributed = 0; int botInvestment = botState.getTotalInvestment(); double sampleEV = 0; for (Iterator<PlayerState> iter = winDistributions.keySet().iterator(); iter.hasNext();) { PlayerState opponent = iter.next(); int toDistribute = Math.min(botInvestment, opponent.getTotalInvestment()) - maxDistributed; if (toDistribute > 0) { double pWin = 1; double pNotLose = 1; for (WinDistribution distribution : winDistributions.values()) { //you win when you win from every opponent pWin *= distribution.pWin; //you don't lose when you don't lose from every opponent pNotLose *= distribution.pWin + distribution.pDraw; } sampleEV += toDistribute * pWin; //you draw when you don't lose but don't win everything either; double pDraw = pNotLose - pWin; // assume worst case, with winDistributions.size()+1 drawers //TODO do this better, use rollout or statistics! sampleEV += pDraw * toDistribute / (winDistributions.size() + 1.0); maxDistributed += toDistribute; } iter.remove(); } //get back uncalled investment sampleEV += botInvestment - maxDistributed; totalEV += sampleEV; } model.forgetLastAssumption(); return (1 - gameState.getTableConfiguration().getRake()) * (totalEV / nbCommunitySamples); } private TreeMap<PlayerState, WinDistribution> calcOpponentWinDistributionMap(WinDistribution[] winProbs, double[] deadCardWeights) { TreeMap<PlayerState, WinDistribution> winDistributions = new TreeMap<PlayerState, WinDistribution>(playerComparatorByInvestment); for (PlayerState opponentThatCanWin : activeOpponents) { double[] bucketProb = bucketProbs.get(opponentThatCanWin.getPlayerId()); bucketProb = normalize(multiply(deadCardWeights, bucketProb)); winDistributions.put(opponentThatCanWin, calcOpponentWinDistr(winProbs, bucketProb)); } return winDistributions; } private double[] multiply(double[] a, double[] b) { double[] c = new double[a.length]; for (int i = 0; i < a.length; i++) c[i] = a[i] * b[i]; return c; } private double[] normalize(double[] a) { double[] c = new double[a.length]; double sum = 0; for (int i = 0; i < a.length; i++) sum += a[i]; if (Double.isNaN(sum) || sum == 0 || Double.isInfinite(sum)) { throw new IllegalStateException("Bad probabilities:" + sum + " = " + a); } double invSum = 1 / sum; for (int i = 0; i < a.length; i++) { c[i] = a[i] * invSum; } return c; } private WinDistribution calcOpponentWinDistr(WinDistribution[] winProbs, double[] bucketProbs) { WinDistribution winDistr; double pWin = 0, pDraw = 0, pLose = 0; for (int j = 0; j < bucketProbs.length; j++) { pWin += winProbs[j].pWin * bucketProbs[j]; pDraw += winProbs[j].pDraw * bucketProbs[j]; pLose += winProbs[j].pLose * bucketProbs[j]; } winDistr = new WinDistribution(pWin, pDraw, pLose); return winDistr; } private WinDistribution[] calcWinDistributions(int botRank, Multiset<Integer> ranks, Multiset<Integer> deadRanks) { Iterator<Integer> iter = ranks.iterator(); WinDistribution[] winProbs = new WinDistribution[10]; for (int bucket = 0; bucket < nbBuckets; bucket++) { double winWeight = 0; double drawWeight = 0; double loseWeight = 0; for (int j = 0; j < nbSamplesPerBucket; j++) { int rank = iter.next(); double weight = 1 - deadRanks.count(rank) / ranks.count(rank); if (rank < botRank) { winWeight += weight; } else if (rank > botRank) { loseWeight += weight; } else { drawWeight += weight; } } double nbSamples = winWeight + drawWeight + loseWeight; if (nbSamples == 0) nbSamples = 1; winProbs[bucket] = new WinDistribution(winWeight / nbSamples, drawWeight / nbSamples, loseWeight / nbSamples); } return winProbs; } public static class WinDistribution { //from the perspective of the bot public final double pWin, pDraw, pLose; public WinDistribution(double pWin, double pDraw, double pLose) { this.pWin = pWin; this.pDraw = pDraw; this.pLose = pLose; } @Override public String toString() { return pWin + "/" + pDraw + "/" + pLose; } } private double[] calcDeadCardWeights(Multiset<Integer> ranks, Multiset<Integer> deadRanks) { Iterator<Integer> iter = ranks.iterator(); double[] deadCardWeights = new double[nbBuckets]; for (int bucket = 0; bucket < nbBuckets; bucket++) { double nbDead = 0; for (int j = 0; j < nbSamplesPerBucket; j++) { int rank = iter.next(); double count = ranks.count(rank); double deadCount = deadRanks.count(rank); nbDead += deadCount / count; } deadCardWeights[bucket] = ((nbSamplesPerBucket - nbDead) / nbSamplesPerBucket); } return deadCardWeights; } // private double calcVariance(Multiset<Integer> ranks, double mean) { // double var = 0; // for (Multiset.Entry<Integer> entry : ranks.entrySet()) { // double diff = mean - entry.getElement(); // var += diff * diff * entry.getCount(); // } // var /= (ranks.size()-1); // return var; // } }