package net.demilich.metastone.game.behaviour.mcts;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import net.demilich.metastone.game.GameContext;
import net.demilich.metastone.game.Player;
import net.demilich.metastone.game.actions.GameAction;
import net.demilich.metastone.game.behaviour.PlayRandomBehaviour;
/**
 * A single node of the Monte Carlo search tree.
 *
 * <p>A node stores a copy of the game state it stands for, the action that
 * produced that state, the actions that have not been tried from it yet, and
 * the child nodes created for the actions that have been tried. The
 * {@code visits} and {@code score} counters are updated once for every
 * {@link #process(ITreePolicy)} pass that goes through the node.
 *
 * <p>{@link #initState(GameContext, List)} must be called before any method
 * that reads the state or the list of untried actions.
 */
class Node {

    /** Game state represented by this node; assigned in {@link #initState}. */
    private GameContext state;

    /** Actions not yet turned into children; {@link #expand()} consumes them from the front. */
    private List<GameAction> validTransitions;

    /** Children created so far, in the order they were expanded. */
    private final List<Node> children = new LinkedList<>();

    /** Action that led from the parent's state to this node's state. */
    private final GameAction incomingAction;

    /** Number of times this node was part of a processed path. */
    private int visits;

    /** Sum of all roll-out values propagated through this node. */
    private int score;

    /** Id of the player whose win counts as a positive roll-out result. */
    private final int player;

    /**
     * Creates a node without state; call {@link #initState} before using it.
     *
     * @param incomingAction the action that leads to this node
     * @param player         id of the player the search is performed for
     */
    public Node(GameAction incomingAction, int player) {
        this.incomingAction = incomingAction;
        this.player = player;
    }

    /** Returns {@code true} while at least one action has not been expanded yet. */
    private boolean canFurtherExpanded() {
        return !validTransitions.isEmpty();
    }

    /**
     * Takes the next untried action, applies it to a copy of this node's
     * state and registers the resulting node as a new child.
     *
     * @return the newly created child
     */
    private Node expand() {
        GameAction action = validTransitions.remove(0);
        GameContext newState = state.clone();

        try {
            newState.getLogic().performGameAction(newState.getActivePlayer().getId(), action);
        } catch (Exception e) {
            // Report which action failed before letting the exception propagate.
            System.err.println("Exception on action: " + action + " state decided: " + state.gameDecided());
            e.printStackTrace();
            throw e;
        }

        Node child = new Node(action, getPlayer());
        child.initState(newState, newState.getValidActions());
        children.add(child);
        return child;
    }

    /**
     * Returns the incoming action of the child with the highest score.
     * When several children share the highest score the first one wins.
     *
     * @return the best action, or {@code null} if this node has no children
     */
    public GameAction getBestAction() {
        GameAction best = null;
        int bestScore = Integer.MIN_VALUE;
        for (Node node : children) {
            if (node.getScore() > bestScore) {
                best = node.incomingAction;
                bestScore = node.getScore();
            }
        }
        return best;
    }

    /** Returns the live list of children (not a copy). */
    public List<Node> getChildren() {
        return children;
    }

    public int getPlayer() {
        return player;
    }

    public int getScore() {
        return score;
    }

    public GameContext getState() {
        return state;
    }

    public int getVisits() {
        return visits;
    }

    /**
     * Sets the state of this node and the actions that may be expanded from it.
     * Both arguments are copied, so later changes made by the caller do not
     * affect this node.
     *
     * @param state        game state this node represents
     * @param validActions actions available in {@code state}
     */
    public void initState(GameContext state, List<GameAction> validActions) {
        this.state = state.clone();
        this.validTransitions = new ArrayList<>(validActions);
    }

    /**
     * Tells whether another child can still be created for this node.
     *
     * <p>{@link #expand()} removes an action from the untried list every time
     * it adds a child, so comparing the number of children with the number of
     * remaining actions would report "not expandable" while untried actions
     * are still left. The node is expandable exactly when untried actions
     * remain and the game is not decided.
     *
     * @return {@code true} if an untried action exists and the game is undecided
     */
    public boolean isExpandable() {
        return canFurtherExpanded() && !isTerminal();
    }

    /** Returns {@code true} if no child has been created for this node yet. */
    public boolean isLeaf() {
        return children.isEmpty();
    }

    /** Returns {@code true} if the game is decided in this node's state. */
    private boolean isTerminal() {
        return state.gameDecided();
    }

    /**
     * Runs one search iteration starting at this node: descends until a node
     * with an untried action is found (which is then expanded once) or a
     * terminal node is reached, performs a roll-out from the final node and
     * adds the result to every node on the visited path.
     *
     * @param treePolicy policy used to pick a child of a node that has no
     *                   untried actions left
     */
    public void process(ITreePolicy treePolicy) {
        List<Node> visited = new ArrayList<>();
        Node current = this;
        visited.add(current);

        while (!current.isTerminal()) {
            if (current.canFurtherExpanded()) {
                // Only one new node is added per iteration.
                current = current.expand();
                visited.add(current);
                break;
            }
            // NOTE(review): assumes the policy always returns a child here - confirm.
            current = treePolicy.select(current);
            visited.add(current);
        }

        int value = rollOut(current);
        for (Node node : visited) {
            node.updateStats(value);
        }
    }

    /**
     * Estimates the value of {@code node} from the point of view of this
     * node's player.
     *
     * <p>If the game is already decided in {@code node}, its state is judged
     * directly. Otherwise a copy of the state is played with
     * {@link PlayRandomBehaviour} assigned to every player.
     * NOTE(review): the simulation only calls {@code playTurn()}; whether that
     * plays the game through to a decision depends on {@code GameContext} -
     * confirm.
     *
     * @param node the node to evaluate
     * @return {@code 1} if the winning player id equals {@link #getPlayer()},
     *         otherwise {@code 0}
     */
    public int rollOut(Node node) {
        if (node.getState().gameDecided()) {
            GameContext state = node.getState();
            return state.getWinningPlayerId() == getPlayer() ? 1 : 0;
        }

        // Work on a copy so the state stored in the tree stays untouched.
        GameContext simulation = node.getState().clone();
        for (Player player : simulation.getPlayers()) {
            player.setBehaviour(new PlayRandomBehaviour());
        }

        simulation.playTurn();

        return simulation.getWinningPlayerId() == getPlayer() ? 1 : 0;
    }

    /** Records one more visit and adds {@code value} to the accumulated score. */
    private void updateStats(int value) {
        visits++;
        score += value;
    }
}