/** * Copyright (C) 2017 Jan Schäfer (jansch@users.sourceforge.net) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jskat.ai.nn.data; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.jskat.ai.nn.input.GenericNetworkInputGenerator; import org.jskat.ai.nn.util.EncogNetworkWrapper; import org.jskat.ai.nn.util.INeuralNetwork; import org.jskat.ai.nn.util.NetworkTopology; import org.jskat.util.GameType; /** * Holds all neural networks for the NN player. */ public final class SkatNetworks { private static int INPUT_NEURONS = GenericNetworkInputGenerator.getNeuronCountForAllStrategies(); private static int OUTPUT_NEURONS = 1; private static int HIDDEN_NEURONS = (INPUT_NEURONS + OUTPUT_NEURONS) * 2; private static final boolean USE_BIAS = true; private final static SkatNetworks INSTANCE = new SkatNetworks(); private static Map<GameType, Map<PlayerParty, List<INeuralNetwork>>> networks; /** * Private constructor for singleton class. */ private SkatNetworks() { createNetworks(); // loadNetworks(); } /** * Gets a neural network. * * @param gameType * Game type * @param isDeclarer * TRUE, if declarer network is desired * @param trickNoInGame * Trick number in current game * @return Neural network */ public static INeuralNetwork getNetwork(GameType gameType, boolean isDeclarer, int trickNoInGame) { Map<PlayerParty, List<INeuralNetwork>> gameTypeNets = networks.get(gameType); List<INeuralNetwork> playerPartyNets = null; if (GameType.RAMSCH.equals(gameType) || isDeclarer) { playerPartyNets = gameTypeNets.get(PlayerParty.DECLARER); } else { playerPartyNets = gameTypeNets.get(PlayerParty.OPPONENT); } return playerPartyNets.get(trickNoInGame); } /** * Gets an instance of the SkatNetworks. * * @return Instance */ public static SkatNetworks instance() { return INSTANCE; } /** * Loads all neural networks from files. */ public static void loadNetworks() { for (Entry<GameType, Map<PlayerParty, List<INeuralNetwork>>> gameTypeNets : networks.entrySet()) { for (Entry<PlayerParty, List<INeuralNetwork>> playerPartyNet : gameTypeNets.getValue().entrySet()) { for (int trick = 0; trick < 10; trick++) { playerPartyNet.getValue().get(trick) .loadNetwork("/org/jskat/ai/nn/data/jskat".concat("." + gameTypeNets.getKey()) .concat("." + playerPartyNet.getKey()).concat(".TRICK" + trick).concat(".nnet"), INPUT_NEURONS, HIDDEN_NEURONS, OUTPUT_NEURONS); } } } } /** * Resets neural networks */ public static void resetNeuralNetworks() { createNetworks(); } /** * Saves all networks to files * * @param savePath * Path to files */ public static void saveNetworks(final String savePath) { for (GameType gameType : GameType.values()) { saveNetworks(savePath, gameType); } } /** * Saves all networks of the specified game type to files * * @param savePath * Path to files * @param gameType * Game type */ public static void saveNetworks(final String savePath, GameType gameType) { Map<PlayerParty, List<INeuralNetwork>> gameTypeNetworks = networks.get(gameType); for (Entry<PlayerParty, List<INeuralNetwork>> playerPartyNets : gameTypeNetworks.entrySet()) { for (int trick = 0; trick < 10; trick++) { playerPartyNets.getValue().get(trick).saveNetwork(savePath.concat("jskat").concat("." + gameType) .concat("." + playerPartyNets.getKey()).concat(".TRICK" + trick).concat(".nnet")); } } } private static void createNetworks() { int[] hiddenLayer = { HIDDEN_NEURONS }; NetworkTopology topo = new NetworkTopology(INPUT_NEURONS, hiddenLayer, OUTPUT_NEURONS); networks = new HashMap<GameType, Map<PlayerParty, List<INeuralNetwork>>>(); for (GameType gameType : GameType.values()) { networks.put(gameType, new HashMap<PlayerParty, List<INeuralNetwork>>()); for (PlayerParty playerParty : PlayerParty.values()) { List<INeuralNetwork> networkList = new ArrayList<>(); networks.get(gameType).put(playerParty, networkList); EncogNetworkWrapper network = new EncogNetworkWrapper(topo, USE_BIAS); for (int trick = 0; trick < 10; trick++) { networkList.add(network); } } } } }