/** * Copyright (C) 2017 Jan Schäfer (jansch@users.sourceforge.net) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jskat.ai.nn; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import org.jskat.ai.AbstractAIPlayer; import org.jskat.ai.nn.data.SkatNetworks; import org.jskat.ai.nn.input.GenericNetworkInputGenerator; import org.jskat.ai.nn.input.NetworkInputGenerator; import org.jskat.ai.nn.util.INeuralNetwork; import org.jskat.data.GameAnnouncement; import org.jskat.data.GameAnnouncement.GameAnnouncementFactory; import org.jskat.data.GameSummary; import org.jskat.data.SkatGameData; import org.jskat.data.SkatGameResult; import org.jskat.player.JSkatPlayer; import org.jskat.util.Card; import org.jskat.util.CardList; import org.jskat.util.GameType; import org.jskat.util.Player; import org.jskat.util.rule.SkatRule; import org.jskat.util.rule.SkatRuleFactory; import org.slf4j.Logger; /** * JSkat player using neural network */ public class AIPlayerNN extends AbstractAIPlayer { private final static Long MAX_SIMULATIONS_DISCARDING = 2000L; private final static Long MAX_TIME_DISCARDING = 5000L; private final static Long MAX_SIMULATIONS_BIDDING = 100L; private final static Long MAX_SIMULATIONS_HAND_GAME = 500L; private final static Double MIN_WON_RATE_FOR_BIDDING = 0.6; private final static Double MIN_WON_RATE_FOR_DISCARDING = 0.75; private final static Double MIN_WON_RATE_FOR_HAND_GAME = 0.95; public final static Double IDEAL_WON = 1.0; public final static Double IDEAL_LOST = 0.0; public final static Double EPSILON = 0.2; // FIXME (jan 10.03.2012) code duplication with NNTrainer private static boolean isRamschGameWon(final GameSummary gameSummary, final Player currPlayer) { boolean ramschGameWon = false; final int playerPoints = gameSummary.getPlayerPoints(currPlayer); int highestPlayerPoints = 0; for (final Player player : Player.values()) { final int currPlayerPoints = gameSummary.getPlayerPoints(player); if (currPlayerPoints > highestPlayerPoints) { highestPlayerPoints = currPlayerPoints; } } if (highestPlayerPoints > playerPoints) { ramschGameWon = true; } return ramschGameWon; } private final DecimalFormat formatter = new DecimalFormat("0.00000000000000000"); //$NON-NLS-1$ private final GameSimulator2 gameSimulator2; private final NetworkInputGenerator inputGenerator; private final static Random RANDOM = new Random(); private final List<double[]> allInputs = new ArrayList<>(); private GameType bestGameTypeFromDiscarding; private boolean isLearning = false; private double lastAvgNetworkError = 0.0; private final List<GameType> feasibleGameTypes = new ArrayList<GameType>(); /** * Constructor */ public AIPlayerNN() { this("unknown", null); //$NON-NLS-1$ } /** * Constructor * * @param logger * Logger to be used, allows to set NOPLogger from outside */ public AIPlayerNN(final Logger logger) { this("unknown", logger); } /** * Creates a new instance of AIPlayerNN * * @param newPlayerName * Player's name * @param logger * Logger to be used, allows to set NOPLogger from outside */ public AIPlayerNN(final String newPlayerName, final Logger logger) { log.debug("Constructing new AIPlayerNN"); //$NON-NLS-1$ setPlayerName(newPlayerName); if (logger != null) { log = logger; } inputGenerator = new GenericNetworkInputGenerator(); gameSimulator2 = new GameSimulator2(); for (final GameType gameType : GameType.values()) { if (gameType != GameType.RAMSCH && gameType != GameType.PASSED_IN) { feasibleGameTypes.add(gameType); } } } /** * @see JSkatPlayer#announceGame() */ @Override public GameAnnouncement announceGame() { log.debug("position: " + knowledge.getPlayerPosition()); //$NON-NLS-1$ log.debug("bids: " + knowledge.getHighestBid(Player.FOREHAND) + //$NON-NLS-1$ " " + knowledge.getHighestBid(Player.MIDDLEHAND) + //$NON-NLS-1$ " " + knowledge.getHighestBid(Player.REARHAND)); //$NON-NLS-1$ final GameAnnouncementFactory factory = GameAnnouncement.getFactory(); if (bestGameTypeFromDiscarding != null) { // use game type from discarding factory.setGameType(bestGameTypeFromDiscarding); } factory.setHand(knowledge.isHandGame()); // FIXME (jan 17.01.2011) setting ouvert and schneider/schwarz // newGame.setOuvert(rand.nextBoolean()); final GameAnnouncement newGame = factory.getAnnouncement(); log.debug("Announcing: " + newGame); //$NON-NLS-1$ return newGame; } /** * @see JSkatPlayer#bidMore(int) */ @Override public Integer bidMore(final int nextBidValue) { int result = -1; if (isAnyGamePossible(nextBidValue)) { result = nextBidValue; } return result; } @Override public Boolean callContra() { // TODO Auto-generated method stub return false; } @Override public Boolean callRe() { // TODO Auto-generated method stub return false; } /** * @see org.jskat.player.JSkatPlayer#finalizeGame() */ @Override public void finalizeGame() { if (isLearning && allInputs.size() > 0) { // adjust neural networks // from last trick to first trick adjustNeuralNetworks(allInputs); } } /** * @see JSkatPlayer#discardSkat() */ @Override public CardList getCardsToDiscard() { final CardList cards = knowledge.getOwnCards(); log.debug("Player cards before discarding: " + knowledge.getOwnCards()); //$NON-NLS-1$ final List<GameType> filteredGameTypes = filterFeasibleGameTypes( knowledge.getHighestBid(knowledge.getPlayerPosition()).intValue()); gameSimulator2.reset(); // create all possible discards int simCount = 0; for (int i = 0; i < cards.size() - 1; i++) { for (int j = i + 1; j < cards.size(); j++) { simCount++; final CardList simCards = new CardList(cards); final CardList currSkat = new CardList(simCards.get(i), simCards.get(j)); simCards.removeAll(currSkat); log.debug("Discard simulation no. " + simCount + ": skat " + currSkat); for (final GameType gameType : filteredGameTypes) { gameSimulator2.add(new GameSimulation(gameType, knowledge.getPlayerPosition(), simCards, currSkat)); } } } final GameSimulation bestSimulation = gameSimulator2.simulateMaxEpisodes(1000L); bestGameTypeFromDiscarding = bestSimulation.getGameType(); log.warn("Simulated " + bestSimulation.getEpisodes() + " episodes with highest won rate of " + bestSimulation.getWonRate() + " discarded cards " + bestSimulation.getSkatCards()); return bestSimulation.getSkatCards(); } private CardList getRandomEntry(final List<CardList> possibleSkats) { return possibleSkats.get(RANDOM.nextInt(possibleSkats.size())); } /** * Gets the last average network error * * @return Last average network error */ public double getLastAvgNetworkError() { return lastAvgNetworkError; } /** * @see JSkatPlayer#holdBid(int) */ @Override public Boolean holdBid(final int currBidValue) { return isAnyGamePossible(currBidValue); } /** * @see JSkatPlayer#pickUpSkat() */ @Override public Boolean pickUpSkat() { log.warn("Check hand game or pick up skat..."); gameSimulator2.reset(); final List<GameType> filteredGameTypes = filterFeasibleGameTypes( knowledge.getHighestBid(knowledge.getPlayerPosition()).intValue()); for (final GameType gameType : filteredGameTypes) { gameSimulator2.add(new GameSimulation(gameType, knowledge.getPlayerPosition(), knowledge.getOwnCards())); } final GameSimulation bestSimulation = gameSimulator2.simulateMaxEpisodes(MAX_SIMULATIONS_HAND_GAME); log.warn("Simulated " + bestSimulation.getEpisodes() + " episodes with best won rate of " + bestSimulation.getWonRate()); if (bestSimulation.getWonRate() >= MIN_WON_RATE_FOR_HAND_GAME) { log.warn("Min won rate reached. Playing hand..."); bestGameTypeFromDiscarding = bestSimulation.getGameType(); return false; } log.warn("Min won rate not reached. Picking up skat..."); return true; } /** * @see JSkatPlayer#playCard() */ @Override public Card playCard() { int bestCardIndex = -1; log.debug('\n' + knowledge.toString()); // first find all possible cards final CardList possibleCards = getPlayableCards(knowledge.getTrickCards()); log.debug("found " + possibleCards.size() + " possible cards: " //$NON-NLS-1$//$NON-NLS-2$ + possibleCards); final Map<Card, double[]> cardInputs = new HashMap<Card, double[]>(); final INeuralNetwork net = SkatNetworks.getNetwork(knowledge.getGameAnnouncement().getGameType(), isDeclarer(), knowledge.getCurrentTrick().getTrickNumberInGame()); final CardList bestCards = new CardList(); final CardList highestOutputCards = new CardList(); Double highestOutput = Double.NEGATIVE_INFINITY; for (final Card card : possibleCards) { log.debug("Testing card " + card); //$NON-NLS-1$ final double[] inputs = inputGenerator.getNetInputs(knowledge, card); cardInputs.put(card, inputs); final Double currOutput = net.getPredictedOutcome(inputs); log.warn("net output for card " + card + ": " //$NON-NLS-1$ + formatter.format(currOutput)); if (currOutput > (IDEAL_WON - EPSILON)) { bestCards.add(card); } if (currOutput > highestOutput && !formatter.format(currOutput).equals(formatter.format(highestOutput))) { highestOutput = currOutput; highestOutputCards.clear(); highestOutputCards.add(card); } else if (currOutput == highestOutput || formatter.format(currOutput).equals(formatter.format(highestOutput))) { highestOutputCards.add(card); } } if (bestCards.size() > 0) { // get random card out of the best cards bestCardIndex = chooseRandomCard(possibleCards, bestCards); log.warn("Trick " + (knowledge.getNoOfTricks() + 1) //$NON-NLS-1$ + ": Found best cards. Choosing random from " //$NON-NLS-1$ + bestCards.size() + " out of " + possibleCards.size() + ": " //$NON-NLS-1$ //$NON-NLS-2$ + possibleCards.get(bestCardIndex)); } else { // no best card, get card with best output bestCardIndex = chooseRandomCard(possibleCards, highestOutputCards); log.warn("Trick " + (knowledge.getNoOfTricks() + 1) //$NON-NLS-1$ + ": Found no best cards. Choosing card from " //$NON-NLS-1$ + highestOutputCards.size() + " out of " + possibleCards.size() + " with highest output: " + possibleCards.get(bestCardIndex)); // no best card, get random card out of all cards // bestCardIndex = chooseRandomCard(possibleCards, possibleCards); } // store parameters for the card to play // for adjustment of weights after the game storeInputParameters(cardInputs.get(possibleCards.get(bestCardIndex))); log.debug("choosing card " + bestCardIndex); //$NON-NLS-1$ log.debug("as player " + knowledge.getPlayerPosition() + ": " //$NON-NLS-1$//$NON-NLS-2$ + possibleCards.get(bestCardIndex)); return possibleCards.get(bestCardIndex); } @Override public Boolean playGrandHand() { return Boolean.FALSE; } /** * @see org.jskat.player.JSkatPlayer#preparateForNewGame() */ @Override public void preparateForNewGame() { bestGameTypeFromDiscarding = null; allInputs.clear(); } /** * Sets the player into learning mode * * @param newIsLearning * TRUE if the player should learn during play */ public void setIsLearning(final boolean newIsLearning) { isLearning = newIsLearning; } /** * @see org.jskat.player.AbstractJSkatPlayer#startGame() */ @Override public void startGame() { // CHECK Auto-generated method stub } private void adjustNeuralNetworks(final List<double[]> inputs) { double output = 0.0d; if (!GameType.PASSED_IN.equals(knowledge.getGameType())) { if (GameType.RAMSCH.equals(knowledge.getGameType())) { if (isRamschGameWon(gameSummary, knowledge.getPlayerPosition())) { output = IDEAL_WON; } else { output = IDEAL_LOST; } } else { if (isDeclarer()) { if (gameSummary.isGameWon()) { output = IDEAL_WON; } else { output = IDEAL_LOST; } } else { if (gameSummary.isGameWon()) { output = IDEAL_LOST; } else { output = IDEAL_WON; } } } log.warn("Learning output: " + output); final double[][] inputsArray = new double[inputs.size()][]; final double[][] outputsArray = new double[inputs.size()][]; for (int i = 0; i < inputs.size(); i++) { inputsArray[i] = inputs.get(i); outputsArray[i] = new double[] { output }; } final INeuralNetwork net = SkatNetworks.getNetwork(knowledge.getGameAnnouncement().getGameType(), isDeclarer(), 0); final double networkError = net.adjustWeightsBatch(inputsArray, outputsArray); log.warn("learning error: " + networkError); lastAvgNetworkError = networkError; } } private static int chooseRandomCard(final CardList possibleCards, final CardList goodCards) { int bestCardIndex; final Card choosenCard = goodCards.get(RANDOM.nextInt(goodCards.size())); bestCardIndex = possibleCards.indexOf(choosenCard); return bestCardIndex; } private List<GameType> filterFeasibleGameTypes(final int bidValue) { // FIXME (jansch 14.09.2011) consider hand and ouvert games // return game announcement instead final List<GameType> result = new ArrayList<GameType>(); final SkatGameData data = getGameDataForWonGame(); for (final GameType gameType : feasibleGameTypes) { final GameAnnouncementFactory factory = GameAnnouncement.getFactory(); factory.setGameType(gameType); data.setAnnouncement(factory.getAnnouncement()); final SkatRule skatRules = SkatRuleFactory.getSkatRules(gameType); final int currGameResult = skatRules.calcGameResult(data); if (currGameResult >= bidValue) { result.add(gameType); } } return result; } private SkatGameData getGameDataForWonGame() { final SkatGameData data = new SkatGameData(); // it doesn't matter which position is set for declarer // skat game data are only used to calculate the game value data.setDeclarer(Player.FOREHAND); data.addDealtCards(Player.FOREHAND, knowledge.getOwnCards()); data.addSkatToPlayer(Player.FOREHAND); final SkatGameResult result = new SkatGameResult(); result.setWon(true); data.setResult(result); return data; } private boolean isAnyGamePossible(final int bidValue) { final List<GameType> filteredGameTypes = filterFeasibleGameTypes(bidValue); log.warn("Game simulation on bidding: bid value " + bidValue); log.warn("Player position: " + knowledge.getPlayerPosition() + " cards: " + knowledge.getOwnCards()); gameSimulator2.reset(); for (final GameType gameType : filteredGameTypes) { gameSimulator2.add(new GameSimulation(gameType, knowledge.getPlayerPosition(), knowledge.getOwnCards())); } final GameSimulation bestSimulation = gameSimulator2.simulateMaxEpisodes(MAX_SIMULATIONS_BIDDING); log.warn("Simulated " + bestSimulation.getEpisodes() + " episodes with highest won rate of " + bestSimulation.getWonRate()); if (bestSimulation.getWonRate() >= MIN_WON_RATE_FOR_BIDDING) { log.warn("Min won rate reached. Bidding..."); return true; } log.warn("Min won rate not reached. Passing..."); return false; } private void storeInputParameters(final double[] inputParameters) { allInputs.add(inputParameters); } }