package net.demilich.metastone.game.behaviour.learning;
import java.util.List;
import net.demilich.metastone.game.GameContext;
import net.demilich.metastone.game.Player;
import net.demilich.metastone.game.behaviour.neutralnetwork.HiddenUnit;
import net.demilich.metastone.game.behaviour.neutralnetwork.NeuralNetwork;
import net.demilich.metastone.game.entities.minions.Minion;
import net.demilich.metastone.utils.MathUtils;
/**
 * Neural-network based game-state evaluator trained online with the
 * TD(lambda) algorithm (temporal-difference learning with eligibility
 * traces). The network has one hidden layer; its single scalar output is
 * interpreted as the estimated utility of a game state for a player.
 */
public class Brain implements IBrain {

    /** Number of input features: 7 per player plus the current turn. */
    private static final int INPUTS = 15;
    private static final int HIDDEN_NEURONS = 40;
    private static final int OUTPUTS = 1;
    /** Learning rate for the input-to-hidden weights. */
    private static final double ALPHA = 0.1;
    /** Learning rate for the hidden-to-output weights. */
    private static final double BETA = 0.1;
    /** Eligibility trace decay factor (the "lambda" in TD(lambda)). */
    private static final double LAMBDA = 0.5;

    private boolean learning;
    private NeuralNetwork neuralNetwork;
    /** Eligibility traces for hidden-to-output weights, indexed [hidden][output]. */
    private double[][] ew;
    /** Eligibility traces for input-to-hidden weights, indexed [input][hidden][output]. */
    private double[][][] ev;

    public Brain() {
        neuralNetwork = new NeuralNetwork(INPUTS, new int[] { HIDDEN_NEURONS, OUTPUTS });
        // Reuse the reset method instead of duplicating the allocation code.
        wipeEligibilityTraces();
    }

    /**
     * Updates the network weights by backpropagating the TD error while
     * maintaining decaying eligibility traces (the TD(lambda) update).
     *
     * @param in       the input vector the network was evaluated on
     * @param out      the network's output for {@code in}
     * @param expected the target output (next state's value plus reward)
     */
    private void backPropagation(double[] in, double[] out, double[] expected) {
        // Decay the eligibility traces and accumulate the current gradients.
        for (int j = 0; j < neuralNetwork.hidden[0].length; j++) {
            for (int k = 0; k < out.length; k++) {
                ew[j][k] = LAMBDA * ew[j][k] + gradient(neuralNetwork.hidden[1][k]) * neuralNetwork.hidden[0][j].getValue();
                for (int i = 0; i < in.length; i++) {
                    ev[i][j][k] = LAMBDA * ev[i][j][k] + gradient(neuralNetwork.hidden[1][k]) * neuralNetwork.hidden[1][k].weights[j]
                            * gradient(neuralNetwork.hidden[0][j]) * in[i];
                }
            }
        }
        // TD error per output unit.
        double[] error = new double[out.length];
        for (int k = 0; k < out.length; k++) {
            error[k] = expected[k] - out[k];
        }
        // Apply the weight updates, scaled by the eligibility traces.
        for (int j = 0; j < neuralNetwork.hidden[0].length; j++) {
            for (int k = 0; k < out.length; k++) {
                // Weight from hidden unit j to output unit k, learning rate BETA.
                neuralNetwork.hidden[1][k].weights[j] += BETA * error[k] * ew[j][k];
                for (int i = 0; i < in.length; i++) {
                    neuralNetwork.hidden[0][j].weights[i] += ALPHA * error[k] * ev[i][j][k];
                }
            }
        }
    }

    /**
     * Encodes one player's observable state as 7 normalized features written
     * into {@code data} starting at {@code offset}: minion count, total minion
     * attack, total minion hp, hero attack, effective hero health (hp plus
     * armor), hand size and deck size. The divisors are rough upper bounds
     * used to keep the features roughly in [0, 1].
     */
    private void encodePlayer(Player player, double[] data, int offset) {
        List<Minion> minions = player.getMinions();
        int totalMinionAttack = 0;
        int totalMinionHp = 0;
        // The battlefield holds at most 7 minions; use the cached list
        // instead of re-fetching it on every iteration.
        int minionCount = Math.min(minions.size(), 7);
        for (int i = 0; i < minionCount; i++) {
            Minion minion = minions.get(i);
            totalMinionAttack += minion.getAttack();
            totalMinionHp += minion.getHp();
        }
        data[offset++] = minions.size() / 7.0;
        data[offset++] = MathUtils.clamp01(totalMinionAttack / 40.0);
        data[offset++] = MathUtils.clamp01(totalMinionHp / 40.0);
        data[offset++] = MathUtils.clamp01(player.getHero().getAttack() / 10.0);
        data[offset++] = MathUtils.clamp01((player.getHero().getHp() + player.getHero().getArmor()) / 40.0);
        data[offset++] = player.getHand().getCount() / 10.0;
        data[offset++] = MathUtils.clamp01(player.getDeck().getCount() / 30.0);
    }

    /**
     * Builds the network input vector for the given player's perspective:
     * 7 features for the player, 7 for the opponent, plus the (clamped,
     * normalized) turn number in the last slot.
     */
    private double[] gameStateToInput(GameContext context, int playerId) {
        double[] input = new double[INPUTS];
        Player player = context.getPlayer(playerId);
        Player opponent = context.getOpponent(player);
        encodePlayer(player, input, 0);
        encodePlayer(opponent, input, INPUTS / 2);
        input[INPUTS - 1] = MathUtils.clamp01(context.getTurn() / 10.0);
        return input;
    }

    /** Extracts the scalar utility estimate from a network output vector. */
    @Override
    public double getEstimatedUtility(double[] output) {
        return output[0];
    }

    /** Evaluates the network on the given game state from {@code playerId}'s perspective. */
    @Override
    public double[] getOutput(GameContext context, int playerId) {
        double[] input = gameStateToInput(context, playerId);
        return neuralNetwork.getValue(input);
    }

    /** Derivative of the sigmoid activation, expressed via the unit's cached output value. */
    private double gradient(HiddenUnit hiddenUnit) {
        return hiddenUnit.getValue() * (1.0 - hiddenUnit.getValue());
    }

    @Override
    public boolean isLearning() {
        return learning;
    }

    /**
     * Performs one TD(lambda) learning step: the target is the next state's
     * network output shifted by {@code reward}, and the error against the
     * current state's output is backpropagated.
     *
     * @param originalState the state the agent acted from
     * @param playerId      the learning player's id
     * @param nextOutput    the network output for the successor state; not modified
     * @param reward        the immediate reward received for the transition
     */
    @Override
    public void learn(GameContext originalState, int playerId, double[] nextOutput, double reward) {
        double[] currentInput = gameStateToInput(originalState, playerId);
        // Reuse the already-encoded input rather than re-encoding the state
        // via getOutput.
        double[] currentOutput = neuralNetwork.getValue(currentInput);
        // Build the TD target in a copy so the caller's array is not mutated.
        double[] expected = new double[nextOutput.length];
        for (int i = 0; i < nextOutput.length; i++) {
            expected[i] = nextOutput[i] + reward;
        }
        backPropagation(currentInput, currentOutput, expected);
    }

    /**
     * Loads previously saved network weights from {@code path}.
     * Persistence via {@code NeuralNetwork.readFrom(path)} is currently
     * disabled; this is a no-op until serialization is restored.
     */
    public void load(String path) {
        // TODO: restore NeuralNetwork.readFrom(path), falling back to the
        // untrained network when no saved data exists.
    }

    /**
     * Saves the network weights to {@code path}.
     * Persistence via {@code NeuralNetwork.writeTo(path)} is currently
     * disabled; this is a no-op until serialization is restored.
     */
    public void save(String path) {
        // TODO: restore NeuralNetwork.writeTo(path).
    }

    @Override
    public void setLearning(boolean learning) {
        this.learning = learning;
    }

    /** Resets all eligibility traces to zero; call at episode boundaries. */
    public void wipeEligibilityTraces() {
        ew = new double[HIDDEN_NEURONS][OUTPUTS];
        ev = new double[INPUTS][HIDDEN_NEURONS][OUTPUTS];
    }

    /**
     * @deprecated misspelled name kept for backward compatibility;
     * use {@link #wipeEligibilityTraces()} instead.
     */
    @Deprecated
    public void wipeEligabilityTraces() {
        wipeEligibilityTraces();
    }
}