package bots.mctsbot.ai.opponentmodels.weka;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import bots.mctsbot.ai.bots.bot.gametree.mcts.MCTSBot;
import bots.mctsbot.ai.bots.bot.gametree.mcts.nodes.INode;
import bots.mctsbot.ai.opponentmodels.OpponentModel;
import bots.mctsbot.ai.opponentmodels.listeners.OpponentModelListener;
import bots.mctsbot.client.common.gamestate.GameState;
import bots.mctsbot.client.common.playerstate.PlayerState;
import bots.mctsbot.common.elements.player.PlayerId;
import bots.mctsbot.common.util.Pair;
import bots.mctsbot.common.util.Triple;
/**
* This OpponentModel delegates to a provided default {@link WekaModel} for its opponent-model.
* In addition it observes the game and (configured by {@link WekaOptions}) replaces
* the opponent-model for each villain after enough data has been collected.
*
*/
public class WekaLearningModel implements OpponentModel {
protected static final Logger logger = Logger.getLogger(WekaLearningModel.class);
private PlayerTrackingVisitor permanentVisitor;
private ActionTrackingVisitor actionTrackingVisitor;
private final Deque<PlayerTrackingVisitor> visitors = new ArrayDeque<PlayerTrackingVisitor>();
Map<PlayerId, WekaRegressionModel> opponentModels = new HashMap<PlayerId, WekaRegressionModel>();
private final WekaRegressionModel defaultModel;
private final WekaOptions config;
private final PlayerId bot;
private final OpponentModelListener[] listeners;
private INode node;
public WekaLearningModel(PlayerId botId, WekaRegressionModel defaultModel, WekaOptions config, OpponentModelListener... listeners) {
this.permanentVisitor = new PlayerTrackingVisitor(this);
this.visitors.add(permanentVisitor);
this.defaultModel = defaultModel;
this.config = config;
this.bot = botId;
this.listeners = listeners;
for (int i = 0; i < listeners.length; i++)
listeners[i].setOpponentModel(this);
if (config.useOnlineLearning()) {
this.actionTrackingVisitor = new ActionTrackingVisitor(this, bot);
}
}
public WekaOptions getConfig() {
return config;
}
// these methods are used by KullbackLeiblerListener
// TODO: better design (this is messy)
public Map<PlayerId, WekaRegressionModel> getOpponentModels() {
return opponentModels;
}
public WekaRegressionModel getDefaultModel() {
return defaultModel;
}
public Propositionalizer getCurrentGamePropositionalizer() {
return visitors.peek().getPropz();
}
// *************************************************
@Override
public void assumePermanently(GameState gameState) {
// make sure we have created Models for all players
Set<PlayerState> seatedPlayers = gameState.getAllSeatedPlayers();
for (PlayerState playerState : seatedPlayers) {
getWekaModel(playerState.getPlayerId());
}
permanentVisitor.readHistory(gameState);
if (actionTrackingVisitor != null) {
actionTrackingVisitor.readHistory(gameState);
}
}
@Override
public void assumeTemporarily(GameState gameState) {
PlayerTrackingVisitor root = visitors.peek();
PlayerTrackingVisitor clonedTopVisitor = root.clone();
clonedTopVisitor.readHistory(gameState);
visitors.push(clonedTopVisitor);
}
@Override
public void forgetLastAssumption() {
visitors.pop();
// the permanentVisitor should never be popped
if (visitors.isEmpty()) {
throw new IllegalStateException("'forgetAssumption' was called more often than 'assumeTemporarily'");
}
}
private WekaRegressionModel getWekaModel(PlayerId actor) {
WekaRegressionModel model = opponentModels.get(actor);
if (model == null) {
model = new WekaRegressionModel(defaultModel);
if (config.useOnlineLearning() && !actor.equals(bot)) {
opponentModels.put(actor, model);
actionTrackingVisitor.getPropz().addPlayer(actor, new ARFFPlayer(actor, model, config, actionTrackingVisitor));
}
}
return model;
}
public ARFFPlayer getPlayer(PlayerId actor) {
if (config.useOnlineLearning() && !actor.equals(bot)) {
return actionTrackingVisitor.getPropz().getARFF(actor);
} else
return null;
}
public double getPlayerAccuracy(PlayerId actor) {
if (config.useOnlineLearning() && !actor.equals(bot)) {
return actionTrackingVisitor.getAccuracy(actor);
} else
return 0.0;
}
@Override
public Pair<Double, Double> getCheckBetProbabilities(GameState gameState, PlayerId actor) {
for (int i = 0; i < listeners.length; i++)
listeners[i].onGetCheckProbabilities(gameState, actor);
return getWekaModel(actor).getCheckBetProbabilities(actor, getCurrentGamePropositionalizer());
}
@Override
public Triple<Double, Double, Double> getFoldCallRaiseProbabilities(GameState gameState, PlayerId actor) {
for (int i = 0; i < listeners.length; i++)
listeners[i].onGetFoldCallRaiseProbabilities(gameState, actor);
return getWekaModel(actor).getFoldCallRaiseProbabilities(actor, getCurrentGamePropositionalizer());
}
@Override
public double[] getShowdownProbabilities(GameState gameState, PlayerId actor) throws UnsupportedOperationException {
for (int i = 0; i < listeners.length; i++)
listeners[i].onGetShowdownProbilities(gameState, actor);
return getWekaModel(actor).getShowdownProbabilities(actor, getCurrentGamePropositionalizer());
}
/**
* Saves the node with the last move played by {@link MCTSBot}.
* Is used to get probabilities of the opponents moves in order
* to calculate the accuracy of predictions by the opponentmodel.
* @param node INode containing last action by MCTSBot
*/
@Override
public void setChosenNode(INode node) {
this.node = node;
}
@Override
public INode getChosenNode() {
return this.node;
}
@Override
public PlayerId getBotId() {
return bot;
}
}