/**
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
package bots.mctsbot.ai.bots.bot.gametree.search.expander;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.log4j.Logger;
import bots.mctsbot.ai.bots.bot.gametree.action.ActionWrapper;
import bots.mctsbot.ai.bots.bot.gametree.action.BetAction;
import bots.mctsbot.ai.bots.bot.gametree.action.ProbabilityAction;
import bots.mctsbot.ai.bots.bot.gametree.action.RaiseAction;
import bots.mctsbot.ai.bots.bot.gametree.search.BotActionNode;
import bots.mctsbot.ai.bots.bot.gametree.search.GameTreeNode;
import bots.mctsbot.ai.bots.bot.gametree.search.InnerGameTreeNode;
import bots.mctsbot.ai.bots.bot.gametree.search.expander.sampling.Sampler;
import bots.mctsbot.common.util.Pair;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multiset;
import com.google.common.collect.TreeMultiset;
import com.google.common.collect.Multiset.Entry;
/**
 * A {@link TokenExpander} that expands a node by randomly sampling its
 * {@link ProbabilityAction}s.
 * <p>
 * Up to {@code 300} actions are drawn (fewer when fewer tokens are available),
 * each according to the probabilities reported by {@code getProbabilityActions()}.
 * Every distinct sampled action becomes one child. A child's weight is the
 * fraction of samples that selected it. The tokens are split among the children
 * either uniformly or in proportion to those sample counts.
 */
public class SamplingExpander extends TokenExpander {

	/** Upper bound on the number of samples drawn per expansion. */
	private static final int MAX_GRANULARITY = 300;

	private final static Logger logger = Logger.getLogger(SamplingExpander.class);

	/**
	 * Orders sampled actions so that the resulting children (and the trace output)
	 * come out in a stable, readable order:
	 * <ol>
	 * <li>descending probability,</li>
	 * <li>for two raises or two bets, descending amount,</li>
	 * <li>otherwise alphabetically by {@code toString()}.</li>
	 * </ol>
	 * Because this comparator backs a {@link TreeMultiset}, actions that compare as
	 * 0 are counted as the same sampled element.
	 */
	private static final Comparator<ProbabilityAction> SAMPLE_ORDER = new Comparator<ProbabilityAction>() {
		@Override
		public int compare(ProbabilityAction o1, ProbabilityAction o2) {
			if (o2.getProbability() < o1.getProbability()) {
				return -1;
			}
			if (o2.getProbability() > o1.getProbability()) {
				return 1;
			}
			if (o1.getAction() instanceof RaiseAction && o2.getAction() instanceof RaiseAction) {
				return compareInts(((RaiseAction) o2.getAction()).amount, ((RaiseAction) o1.getAction()).amount);
			}
			if (o1.getAction() instanceof BetAction && o2.getAction() instanceof BetAction) {
				return compareInts(((BetAction) o2.getAction()).amount, ((BetAction) o1.getAction()).amount);
			}
			// Equal probabilities for different kinds of actions do not make the
			// actions equal; fall back to an alphabetical order.
			return o1.toString().compareTo(o2.toString());
		}
	};

	private final Random random = new Random();

	public SamplingExpander(InnerGameTreeNode node, int tokens, Sampler sampler) {
		super(node, tokens, sampler);
	}

	/**
	 * Returns the sampled children without their weights.
	 *
	 * @param uniformTokens if true, tokens are divided equally among the children;
	 *            otherwise proportionally to how often each action was sampled
	 * @return one (action, child node) pair per distinct sampled action
	 */
	@Override
	public List<Pair<ActionWrapper, GameTreeNode>> getChildren(boolean uniformTokens) {
		List<Pair<ActionWrapper, WeightedNode>> weightedChildren = getWeightedChildren(uniformTokens);
		List<Pair<ActionWrapper, GameTreeNode>> children = new ArrayList<Pair<ActionWrapper, GameTreeNode>>(weightedChildren.size());
		for (Pair<ActionWrapper, WeightedNode> wpair : weightedChildren) {
			children.add(new Pair<ActionWrapper, GameTreeNode>(wpair.getLeft(), wpair.getRight().getNode()));
		}
		return children;
	}

	/**
	 * Samples {@code min(300, tokens)} actions and builds one weighted child per
	 * distinct sampled action.
	 *
	 * @param uniformTokens if true, each child gets {@code tokens / #children}
	 *            tokens; otherwise {@code tokens * count / #samples}
	 * @return an immutable list of (action, weighted child) pairs, ordered by
	 *         descending probability. The list is empty when there are no tokens or
	 *         no actions to sample from.
	 */
	public List<Pair<ActionWrapper, WeightedNode>> getWeightedChildren(boolean uniformTokens) {
		List<ProbabilityAction> probActions = new ArrayList<ProbabilityAction>(getProbabilityActions());
		if (probActions.isEmpty()) {
			// Nothing to sample from; without this guard sampleAction would index -1.
			return ImmutableList.of();
		}
		double[] cumulProb = cumulativeProbabilities(probActions);

		// Sorted multiset, so that the children come out in a stable order.
		Multiset<ProbabilityAction> samples = TreeMultiset.create(SAMPLE_ORDER);
		int nbSamples = Math.min(MAX_GRANULARITY, tokens);
		for (int i = 0; i < nbSamples; i++) {
			samples.add(sampleAction(probActions, cumulProb));
		}

		Set<Entry<ProbabilityAction>> entrySet = samples.entrySet();
		ImmutableList.Builder<Pair<ActionWrapper, WeightedNode>> childrenBuilder = ImmutableList.builder();
		for (Entry<ProbabilityAction> entry : entrySet) {
			// Use long arithmetic: tokens * count can exceed Integer.MAX_VALUE,
			// while the quotient never exceeds tokens.
			int tokensShare = uniformTokens ? tokens / entrySet.size() : (int) ((long) tokens * entry.getCount() / nbSamples);
			double weight = entry.getCount() / (double) nbSamples;
			childrenBuilder.add(new Pair<ActionWrapper, WeightedNode>(entry.getElement(), new WeightedNode(node.getChildAfter(
					entry.getElement(), tokensShare), weight)));
		}
		return childrenBuilder.build();
	}

	/**
	 * Computes the running sum of the action probabilities: element i is the sum
	 * of the probabilities of actions 0..i. Each value is trace-logged.
	 */
	private static double[] cumulativeProbabilities(List<ProbabilityAction> probActions) {
		double[] cumulProb = new double[probActions.size()];
		double sum = 0;
		for (int i = 0; i < probActions.size(); i++) {
			sum += probActions.get(i).getProbability();
			cumulProb[i] = sum;
		}
		if (logger.isTraceEnabled()) {
			for (int i = 0; i < probActions.size(); i++) {
				logger.trace("cumulProb[" + i + "]=" + cumulProb[i] + " for action " + probActions.get(i));
			}
		}
		return cumulProb;
	}

	/**
	 * Draws one action: picks the first action whose cumulative probability
	 * exceeds a uniform random number in [0, 1). If the probabilities sum to less
	 * than that number, the last action is returned.
	 *
	 * @param probActions a non-empty list of actions
	 * @param cumulProb the running sums produced by
	 *            {@link #cumulativeProbabilities(List)} for {@code probActions}
	 */
	private ProbabilityAction sampleAction(List<ProbabilityAction> probActions, double[] cumulProb) {
		double randDouble = random.nextDouble();
		for (int i = 0; i < cumulProb.length; i++) {
			if (randDouble < cumulProb[i]) {
				if (logger.isTraceEnabled()) {
					logger.trace("random " + randDouble + " assigned to " + probActions.get(i));
				}
				return probActions.get(i);
			}
		}
		return probActions.get(probActions.size() - 1);
	}

	/**
	 * Compares two ints without the overflow risk of {@code a - b}.
	 * (Spelled out because the file avoids Java 7's {@code Integer.compare}.)
	 */
	private static int compareInts(int a, int b) {
		return a < b ? -1 : (a == b ? 0 : 1);
	}

	/** Creates {@link SamplingExpander} instances. */
	public static class Factory implements TokenExpander.Factory {
		public SamplingExpander create(InnerGameTreeNode node, int tokens, Sampler sampler) {
			return new SamplingExpander(node, tokens, sampler);
		}
	}
}