/**
* Copyright (C) 2017 Jan Schäfer (jansch@users.sourceforge.net)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jskat.ai.nn.train;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.jskat.ai.nn.AIPlayerNN;
import org.jskat.control.JSkatEventBus;
import org.jskat.control.JSkatMaster;
import org.jskat.control.JSkatThread;
import org.jskat.control.SkatGame;
import org.jskat.control.command.table.CreateTableCommand;
import org.jskat.control.event.nntraining.TrainingResultEvent;
import org.jskat.data.GameAnnouncement;
import org.jskat.data.GameAnnouncement.GameAnnouncementFactory;
import org.jskat.data.GameSummary;
import org.jskat.data.JSkatViewType;
import org.jskat.data.SkatGameData.GameState;
import org.jskat.gui.NullView;
import org.jskat.player.JSkatPlayer;
import org.jskat.util.CardDeck;
import org.jskat.util.GameType;
import org.jskat.util.GameVariant;
import org.jskat.util.Player;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.helpers.NOPLogger;
/**
 * Trains the neural networks by repeatedly playing games of a single game
 * type with neural network players until the training is stopped.
 * <p>
 * The stop flag may be set from a thread other than the trainer thread, see
 * {@link #stopTraining(boolean)}.
 */
public class NNTrainer extends JSkatThread {

    private static final Logger log = LoggerFactory.getLogger(NNTrainer.class);

    /** Upper bound for training episodes; currently only referenced by the disabled loop condition. */
    private static final int MAX_TRAINING_EPISODES = 1000000;
    /** Number of played games after which the networks are saved. */
    private static final int MAX_TRAINING_EPISODES_WITHOUT_SAVE = 100000;

    static final String NEURAL_NETWORK_PLAYER_CLASS = "org.jskat.ai.nn.AIPlayerNN";
    static final String RANDOM_PLAYER_CLASS = "org.jskat.ai.rnd.AIPlayerRND";

    /** Player types that are combined into the player constellations used for training. */
    private static final List<String> playerTypes = new ArrayList<String>();
    static {
        playerTypes.add(NEURAL_NETWORK_PLAYER_CLASS);
    }

    private GameType gameType;

    /** Written by the caller of {@link #stopTraining(boolean)}, read by the trainer thread. */
    private volatile boolean stopTraining = false;

    /**
     * Runs the training loop.
     *
     * @see java.lang.Thread#run()
     */
    @Override
    public void run() {
        trainNets();
    }

    /**
     * Sets the game type to learn and names the thread after it.
     *
     * @param newGameType
     *            Game type
     */
    public void setGameType(final GameType newGameType) {
        this.gameType = newGameType;
        setName("NNTrainer for " + this.gameType); //$NON-NLS-1$
    }

    /**
     * Stops the training
     *
     * @param isStopTraining
     *            TRUE, if the training should be stopped
     */
    public void stopTraining(boolean isStopTraining) {
        this.stopTraining = isStopTraining;
    }

    /**
     * Creates a player of the given type. Neural network players are put into
     * learning mode and get a no-op logger.
     *
     * @param playerType
     *            Class name of the player implementation
     * @return New player
     */
    private JSkatPlayer createPlayer(String playerType) {
        JSkatPlayer player = JSkatMaster.INSTANCE.createPlayer(playerType);
        if (NEURAL_NETWORK_PLAYER_CLASS.equals(playerType)) {
            AIPlayerNN nnPlayer = (AIPlayerNN) player;
            nnPlayer.setIsLearning(true);
            nnPlayer.setLogger(NOPLogger.NOP_LOGGER);
        }
        return player;
    }

    /**
     * Checks whether a finished game counts as won.
     *
     * @param currPlayer
     *            Player whose result is evaluated for ramsch games
     * @param game
     *            Finished game
     * @return TRUE, if the game was won
     */
    private boolean isGameWon(final Player currPlayer, final SkatGame game) {
        // FIXME (jansch 28.06.2011) have to call getGameResult() to get
        // the result
        game.getGameResult();

        if (GameType.RAMSCH.equals(this.gameType)) {
            return isRamschGameWon(game.getGameSummary(), currPlayer);
        }
        return game.isGameWon();
    }

    /**
     * Sets up a game that is ready for trick playing: the players are seated,
     * the cards are dealt and the game announcement is made.
     *
     * @param player1
     *            Player in forehand
     * @param player2
     *            Player in middlehand
     * @param player3
     *            Player in rearhand
     * @param declarer
     *            Position of the declarer; not set on the game for ramsch
     * @param cardDeck
     *            Card deck to deal from, or {@code null} to use a freshly
     *            shuffled deck
     * @return Prepared game in state {@link GameState#TRICK_PLAYING}
     */
    private SkatGame prepareGame(final JSkatPlayer player1, final JSkatPlayer player2, final JSkatPlayer player3,
            final Player declarer, final CardDeck cardDeck) {

        player1.newGame(Player.FOREHAND);
        player2.newGame(Player.MIDDLEHAND);
        player3.newGame(Player.REARHAND);

        final String tableName = "TRAIN" + this.gameType.name();
        JSkatEventBus.INSTANCE.post(new CreateTableCommand(JSkatViewType.TRAINING_TABLE, tableName));
        SkatGame game = new SkatGame(tableName, GameVariant.STANDARD, player1, player2, player3);
        game.setView(new NullView());
        game.setLogger(NOPLogger.NOP_LOGGER);

        if (cardDeck != null) {
            game.setCardDeck(cardDeck);
        } else {
            CardDeck newCardDeck = new CardDeck();
            newCardDeck.shuffle();
            log.debug("Card deck: {}", newCardDeck); //$NON-NLS-1$
            game.setCardDeck(newCardDeck);
        }

        game.dealCards();

        if (!GameType.RAMSCH.equals(this.gameType)) {
            game.setDeclarer(declarer);
        }

        GameAnnouncementFactory factory = GameAnnouncement.getFactory();
        factory.setGameType(this.gameType);
        GameAnnouncement announcement = factory.getAnnouncement();
        game.setGameAnnouncement(announcement);

        player1.startGame(declarer, announcement);
        player2.startGame(declarer, announcement);
        player3.startGame(declarer, announcement);

        game.setGameState(GameState.TRICK_PLAYING);

        return game;
    }

    /**
     * Starts the game and waits until it has finished.
     *
     * @param game
     *            Prepared game
     * @return TRUE, if the game finished; FALSE, if the trainer thread was
     *         interrupted while waiting (the interrupt flag is restored and
     *         the game may still be running)
     */
    private boolean runGame(final SkatGame game) {
        game.start();
        try {
            game.join();
            return true;
        } catch (InterruptedException e) {
            log.warn("Interrupted while waiting for the training game to finish.", e);
            // restore the flag so that code further up can react to the interruption
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Trains the neural networks.
     * <p>
     * Every pass plays one game per player constellation and declarer
     * position. Before each pass (except the first) the networks are saved if
     * enough games were played since the last save, and a
     * {@link TrainingResultEvent} with the current statistics is posted. The
     * loop ends when the training is stopped or the thread is interrupted
     * while waiting for a game.
     */
    private void trainNets() {

        long totalGames = 0;
        long totalWonGames = 0;
        long gamesSinceLastSave = 0;
        double declarerAvgNetworkErrorSum = 0.0;
        long declarerParticipations = 0;
        double opponentAvgNetworkErrorSum = 0.0;
        long opponentParticipations = 0;

        Set<List<String>> playerPermutations = createPlayerPermutations(playerTypes);

        while (!this.stopTraining /* && totalGames < MAX_TRAINING_EPISODES */) {

            if (totalGames > 0) {
                // A counter is used instead of "totalGames % limit == 0": totalGames
                // grows by several games per pass and would hit an exact multiple of
                // the limit far less often than intended.
                if (gamesSinceLastSave >= MAX_TRAINING_EPISODES_WITHOUT_SAVE) {
                    JSkatMaster.INSTANCE.saveNeuralNetworks(this.gameType);
                    gamesSinceLastSave = 0;
                }
                // Averages are reported as 0.0 while a role has no participations yet.
                JSkatEventBus.INSTANCE.post(new TrainingResultEvent(this.gameType, totalGames, totalWonGames,
                        average(declarerAvgNetworkErrorSum, declarerParticipations),
                        average(opponentAvgNetworkErrorSum, opponentParticipations)));
            }

            for (List<String> playerConstellation : playerPermutations) {
                for (Player declarer : Player.values()) {

                    JSkatPlayer player1 = createPlayer(playerConstellation.get(0));
                    JSkatPlayer player2 = createPlayer(playerConstellation.get(1));
                    JSkatPlayer player3 = createPlayer(playerConstellation.get(2));

                    SkatGame game = prepareGame(player1, player2, player3, declarer, null);
                    // SkatGame game = prepareGame(player1, player2, player3,
                    // Player.FOREHAND, getPerfectDistribution());

                    if (!runGame(game)) {
                        // The game may still be running, so its result must not be
                        // evaluated; an interruption ends the training.
                        return;
                    }

                    if (isGameWon(declarer, game)) {
                        log.debug("Game won.");
                        totalWonGames++;
                    } else {
                        log.debug("Game lost.");
                    }

                    for (JSkatPlayer player : new JSkatPlayer[] { player1, player2, player3 }) {
                        if (!(player instanceof AIPlayerNN)) {
                            continue;
                        }
                        AIPlayerNN networkPlayer = (AIPlayerNN) player;
                        if (player.isDeclarer()) {
                            declarerAvgNetworkErrorSum += networkPlayer.getLastAvgNetworkError();
                            declarerParticipations++;
                        } else {
                            opponentAvgNetworkErrorSum += networkPlayer.getLastAvgNetworkError();
                            opponentParticipations++;
                        }
                    }

                    totalGames++;
                    gamesSinceLastSave++;
                }
            }

            checkWaitCondition();
        }
    }

    /**
     * Computes an average that is safe against an empty sample.
     *
     * @param sum
     *            Sum of all values
     * @param count
     *            Number of values
     * @return {@code sum / count}, or 0.0 if {@code count} is zero
     */
    private static double average(final double sum, final long count) {
        return count == 0 ? 0.0 : sum / count;
    }

    /**
     * Creates all ordered combinations of three player types that contain at
     * least one neural network player.
     *
     * @param playerTypes
     *            Player types to combine
     * @return Set of player type triples
     */
    static Set<List<String>> createPlayerPermutations(List<String> playerTypes) {

        Set<List<String>> result = new HashSet<List<String>>();

        for (String player1 : playerTypes) {
            for (String player2 : playerTypes) {
                for (String player3 : playerTypes) {

                    boolean containsNetworkPlayer = NEURAL_NETWORK_PLAYER_CLASS.equals(player1)
                            || NEURAL_NETWORK_PLAYER_CLASS.equals(player2)
                            || NEURAL_NETWORK_PLAYER_CLASS.equals(player3);

                    if (containsNetworkPlayer) {
                        List<String> playerPermutation = new ArrayList<String>();
                        playerPermutation.add(player1);
                        playerPermutation.add(player2);
                        playerPermutation.add(player3);
                        result.add(playerPermutation);
                    }
                }
            }
        }

        return result;
    }

    /**
     * Checks whether a ramsch game counts as won for a player, which is the
     * case if some player has strictly more points than the given player.
     *
     * @param gameSummary
     *            Summary of the finished game
     * @param currPlayer
     *            Player to check
     * @return TRUE, if the highest player points exceed the points of the
     *         given player
     */
    // FIXME (jan 10.03.2012) code duplication with AIPlayerNN
    private static boolean isRamschGameWon(final GameSummary gameSummary, final Player currPlayer) {

        int playerPoints = gameSummary.getPlayerPoints(currPlayer);

        int highestPlayerPoints = 0;
        for (Player player : Player.values()) {
            highestPlayerPoints = Math.max(highestPlayerPoints, gameSummary.getPlayerPoints(player));
        }

        return highestPlayerPoints > playerPoints;
    }
}