/**
* Copyright (C) 2017 Jan Schäfer (jansch@users.sourceforge.net)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jskat.ai.nn.input;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.jskat.data.Trick;
import org.jskat.player.ImmutablePlayerKnowledge;
import org.jskat.util.Card;
import org.jskat.util.CardDeck;
import org.jskat.util.GameType;
import org.jskat.util.Player;
/**
* Creates input signals for neural networks.<br>
* The signals are divided into three parts (65 inputs each):<br>
* Opponent 1 (left neighbor)|Neural network player|Opponent 2 (right neighbor)<br>
* Every player part is further divided into:<br>
* Played cards (32)|Unplayed cards (32)|Other information flags (currently one: declarer)
*/
public class SimpleNetworkInputGenerator implements NetworkInputGenerator {

    /** Number of inputs needed to hold one flag per card of the deck. */
    static final int CARD_DECK_INPUT_LENGTH = 32;
    /** Inputs per player section: played cards, unplayed cards and the declarer flag (65). */
    static final int PLAYER_INPUT_LENGTH = 2 * CARD_DECK_INPUT_LENGTH + 1;
    /** Total inputs: left opponent, neural network player, right opponent (195). */
    static final int INPUT_LENGTH = 3 * PLAYER_INPUT_LENGTH;

    static final double ACTIVE = 1.0d;
    static final double INACTIVE = 0.0d;

    /** Offset of the played-cards inputs inside a player section. */
    private static final int PLAYED_CARDS_OFFSET = 0;
    /** Offset of the unplayed-cards inputs inside a player section. */
    private static final int UNPLAYED_CARDS_OFFSET = CARD_DECK_INPUT_LENGTH;
    /** Offset of the declarer flag inside a player section (its last input). */
    private static final int DECLARER_FLAG_OFFSET = 2 * CARD_DECK_INPUT_LENGTH;
    /** Returned by {@link #getPlayerSectionOffset} when a player has no section. */
    private static final int NO_SECTION = -1;

    /**
     * Builds the network inputs describing the current game state plus a
     * candidate card.
     *
     * @param knowledge
     *            knowledge of the neural network player; its player position
     *            determines which section is "own" and which are opponents
     * @param cardToPlay
     *            candidate card; it is marked as already played by the neural
     *            network player
     * @return a new array of {@link #INPUT_LENGTH} values, each
     *         {@link #ACTIVE} or {@link #INACTIVE}
     */
    @Override
    public double[] getNetInputs(ImmutablePlayerKnowledge knowledge,
            Card cardToPlay) {
        double[] netInputs = new double[INPUT_LENGTH];
        // explicit fill keeps this correct even if INACTIVE differs from the 0.0 default
        Arrays.fill(netInputs, INACTIVE);

        // played cards (32 inputs per player)
        setPlayedCardsInput(netInputs, knowledge);
        // unplayed cards (32 inputs per player)
        setUnplayedCardsInput(netInputs, knowledge);
        // candidate card, encoded inside the own played-cards inputs
        setCardToBePlayedInput(netInputs, cardToPlay);
        // game declarer (1 input per player)
        setDeclarerInput(netInputs, knowledge);

        return netInputs;
    }

    /**
     * Returns the players in input-section order as seen from
     * {@code position}: left neighbor (opponent 1), the player itself, right
     * neighbor (opponent 2).
     */
    private static Player[] getSectionOrder(Player position) {
        return new Player[] { position.getLeftNeighbor(), position,
                position.getRightNeighbor() };
    }

    /**
     * Returns the index of the first input of {@code other}'s section, seen
     * from {@code position}, or {@link #NO_SECTION} if {@code other} is null
     * or not one of the three seats.
     */
    private static int getPlayerSectionOffset(Player position, Player other) {
        if (other == null) {
            return NO_SECTION;
        }
        Player[] sections = getSectionOrder(position);
        for (int section = 0; section < sections.length; section++) {
            if (sections[section] == other) {
                return section * PLAYER_INPUT_LENGTH;
            }
        }
        return NO_SECTION;
    }

    /**
     * Activates the declarer flag of the section belonging to the declarer.
     * Nothing is set in Ramsch games or when the declarer cannot be mapped to
     * a seat (for example when it is null).
     */
    private static void setDeclarerInput(double[] netInputs,
            ImmutablePlayerKnowledge knowledge) {
        if (GameType.RAMSCH.equals(knowledge.getGameType())) {
            // in Ramsch games there is no declarer
            return;
        }
        int sectionOffset = getPlayerSectionOffset(
                knowledge.getPlayerPosition(), knowledge.getDeclarer());
        if (sectionOffset == NO_SECTION) {
            // previously this indexed netInputs[-1]; leave all declarer flags inactive instead
            return;
        }
        netInputs[sectionOffset + DECLARER_FLAG_OFFSET] = ACTIVE;
    }

    /**
     * Activates the played-cards inputs for every card in the completed
     * tricks and in the current trick (if any), attributing each card to the
     * player who played it.
     */
    private static void setPlayedCardsInput(double[] netInputs,
            ImmutablePlayerKnowledge knowledge) {
        Player position = knowledge.getPlayerPosition();

        List<Trick> tricks = new ArrayList<Trick>(
                knowledge.getCompletedTricks());
        Trick currentTrick = knowledge.getCurrentTrick();
        if (currentTrick != null) {
            tricks.add(currentTrick);
        }

        for (Trick trick : tricks) {
            // the first card is attributed to the trick's fore hand, every
            // following card to the previous player's left neighbor
            Player trickPlayer = trick.getForeHand();
            for (Card card : trick.getCardList()) {
                setPlayedCardInput(netInputs, position, trickPlayer, card);
                trickPlayer = trickPlayer.getLeftNeighbor();
            }
        }
    }

    /**
     * Activates the played-cards input of {@code card} in the section of
     * {@code trickPlayer}, seen from {@code position}.
     *
     * @throws IllegalStateException
     *             if {@code trickPlayer} cannot be mapped to a section
     */
    private static void setPlayedCardInput(double[] netInputs,
            Player position, Player trickPlayer, Card card) {
        int sectionOffset = getPlayerSectionOffset(position, trickPlayer);
        if (sectionOffset == NO_SECTION) {
            throw new IllegalStateException("Cannot map trick player "
                    + trickPlayer + " relative to player position " + position);
        }
        netInputs[sectionOffset + PLAYED_CARDS_OFFSET
                + getNetInputIndex(card)] = ACTIVE;
    }

    /**
     * Activates, for every card of the deck and every player, the
     * unplayed-cards input when the knowledge says that player could still
     * hold the card.
     */
    private static void setUnplayedCardsInput(double[] netInputs,
            ImmutablePlayerKnowledge knowledge) {
        // loop-invariant: the seat order only depends on the own position
        Player[] sections = getSectionOrder(knowledge.getPlayerPosition());
        for (Card card : new CardDeck()) {
            int cardIndex = getNetInputIndex(card);
            for (int section = 0; section < sections.length; section++) {
                if (knowledge.couldHaveCard(sections[section], card)) {
                    netInputs[section * PLAYER_INPUT_LENGTH
                            + UNPLAYED_CARDS_OFFSET + cardIndex] = ACTIVE;
                }
            }
        }
    }

    /**
     * Marks the candidate card as played by the neural network player (the
     * middle section).
     */
    private static void setCardToBePlayedInput(double[] netInputs,
            Card cardToPlay) {
        netInputs[PLAYER_INPUT_LENGTH + PLAYED_CARDS_OFFSET
                + getNetInputIndex(cardToPlay)] = ACTIVE;
    }

    /**
     * Maps a card to its position inside a 32-input card block as
     * {@code suitOrder * 8 + nullOrder}.
     * NOTE(review): assumes suit order is 0..3 and null order is 0..7, so the
     * result stays below {@link #CARD_DECK_INPUT_LENGTH}; confirm in Card/Suit.
     */
    private static int getNetInputIndex(final Card card) {
        return card.getSuit().getSuitOrder() * 8 + card.getNullOrder();
    }
}