/**
* This file is part of JSkat.
*
* JSkat is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JSkat is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with JSkat. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jskat.gui.swing.nn;
import java.awt.Component;
import java.awt.Container;
import java.awt.Dimension;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import javax.swing.JDialog;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import javax.swing.SwingConstants;
import javax.swing.table.AbstractTableModel;
import javax.swing.table.DefaultTableCellRenderer;
import javax.swing.table.TableModel;

import org.jskat.gui.swing.LayoutFactory;
import org.jskat.util.GameType;
/**
 * Overview dialog for the training of neural networks.
 * <p>
 * Shows one table row per {@link GameType} with the number of training
 * episodes, the number and percentage of won games, and the average network
 * errors of the declarer and opponents networks.
 */
public class NeuralNetworkTrainingOverview extends JDialog {

	private static final long serialVersionUID = 1L;

	/** Column index of the game type. */
	private static final int COLUMN_GAME_TYPE = 0;
	/** Column index of the number of episodes. */
	private static final int COLUMN_EPISODES = 1;
	/** Column index of the total number of won games. */
	private static final int COLUMN_WON_GAMES = 2;
	/** Column index of the percentage of won games. */
	private static final int COLUMN_PERCENT = 3;
	/** Column index of the average error of the declarer network. */
	private static final int COLUMN_ERROR_DECLARER = 4;
	/** Column index of the average error of the opponents networks. */
	private static final int COLUMN_ERROR_OPPONENTS = 5;

	/** Column headers, indexed by the column constants above. */
	private static final String[] COLUMN_NAMES = { "Game type", "Episodes",
			"Total won games", "Percent", "Network error declarer",
			"Network error opponents" };

	/** Component the dialog is positioned relative to when shown. */
	private final Component parent;

	/** Table showing the training results per game type. */
	JTable overviewTable;

	/**
	 * Constructor
	 *
	 * @param parent
	 *            Parent component of the dialog
	 */
	public NeuralNetworkTrainingOverview(final Component parent) {
		this.parent = parent;
		initGUI();
	}

	/**
	 * Builds the dialog: a scrollable table backed by a
	 * {@link TrainingOverviewTableModel}, with numeric renderers for the
	 * percentage and network error columns.
	 */
	private void initGUI() {
		setMinimumSize(new Dimension(600, 480));
		setTitle("Training of neural networks");

		Container root = getContentPane();
		root.setLayout(LayoutFactory.getMigLayout("fill", "fill", "fill"));

		JPanel rootPanel = new JPanel(LayoutFactory.getMigLayout("fill",
				"fill", "fill"));

		overviewTable = new JTable(new TrainingOverviewTableModel());
		overviewTable.getColumnModel().getColumn(COLUMN_PERCENT)
				.setCellRenderer(new DoubleRenderer(5));
		overviewTable.getColumnModel().getColumn(COLUMN_ERROR_DECLARER)
				.setCellRenderer(new DoubleRenderer(10));
		overviewTable.getColumnModel().getColumn(COLUMN_ERROR_OPPONENTS)
				.setCellRenderer(new DoubleRenderer(10));
		overviewTable.setAutoResizeMode(JTable.AUTO_RESIZE_LAST_COLUMN);

		JScrollPane scrollPane = new JScrollPane(overviewTable);
		rootPanel.add(scrollPane, "grow, center");
		root.add(rootPanel, "center, grow");
	}

	/**
	 * Right-aligned cell renderer that formats numbers with a fixed number of
	 * fraction digits, a space as grouping separator and a comma as decimal
	 * separator. Non-numeric values are shown via {@code toString()}, and
	 * {@code null} is shown as an empty cell.
	 */
	private static class DoubleRenderer extends DefaultTableCellRenderer {

		private static final long serialVersionUID = 1L;

		/** Formatter configured once for this renderer's fraction digits. */
		private final DecimalFormat formatter;

		/**
		 * Constructor
		 *
		 * @param fractionDigits
		 *            Exact number of fraction digits to display
		 */
		public DoubleRenderer(final int fractionDigits) {
			super();
			setHorizontalAlignment(SwingConstants.RIGHT);
			formatter = createFormatter(fractionDigits);
		}

		/**
		 * Creates the number formatter used by this renderer.
		 *
		 * @param fractionDigits
		 *            Exact number of fraction digits
		 * @return configured formatter
		 */
		private static DecimalFormat createFormatter(final int fractionDigits) {
			DecimalFormat format = new DecimalFormat("###,###,###.##");
			format.setMinimumFractionDigits(fractionDigits);
			format.setMaximumFractionDigits(fractionDigits);
			// getDecimalFormatSymbols() returns a copy, so the modified
			// symbols must be set back on the formatter
			DecimalFormatSymbols symbols = format.getDecimalFormatSymbols();
			symbols.setGroupingSeparator(' ');
			symbols.setDecimalSeparator(',');
			format.setDecimalFormatSymbols(symbols);
			return format;
		}

		@Override
		public void setValue(final Object value) {
			if (value == null) {
				setText("");
			} else if (value instanceof Number) {
				setText(formatter.format(value));
			} else {
				// DecimalFormat#format throws IllegalArgumentException for
				// non-numbers, so fall back to the plain string form
				setText(value.toString());
			}
		}
	}

	/**
	 * Adds (or replaces) the training result shown for a game type.
	 * <p>
	 * NOTE(review): Swing models and components are not thread-safe. If this
	 * is called from a training thread, it should be dispatched to the event
	 * dispatch thread (e.g. via {@code SwingUtilities.invokeLater}) — confirm
	 * with the callers.
	 *
	 * @param gameType
	 *            Game type of neural net, must not be {@code null}
	 * @param episodes
	 *            Number of episodes
	 * @param totalWonGames
	 *            Total Number of won games
	 * @param avgNetworkErrorDeclarer
	 *            Average difference of declarer network
	 * @param avgNetworkErrorOpponents
	 *            Average difference of opponents networks
	 */
	public void addTrainingResult(final GameType gameType, final Long episodes,
			final Long totalWonGames, final Double avgNetworkErrorDeclarer,
			final Double avgNetworkErrorOpponents) {
		((TrainingOverviewTableModel) overviewTable.getModel())
				.addTrainingResult(gameType, episodes, totalWonGames,
						avgNetworkErrorDeclarer, avgNetworkErrorOpponents);
	}

	/**
	 * Table model holding one row per {@link GameType}, in the order of
	 * {@link GameType#values()}. Until a result is added, every cell after the
	 * game type column holds {@code 0}.
	 */
	private static class TrainingOverviewTableModel extends AbstractTableModel {

		private static final long serialVersionUID = 1L;

		/** Game types in row order; cached because values() clones the array. */
		private static final GameType[] GAME_TYPES = GameType.values();

		/** Row data per game type, one entry per column. */
		private final Map<GameType, List<Object>> data;

		protected TrainingOverviewTableModel() {
			data = new EnumMap<>(GameType.class);
			for (GameType currGameType : GAME_TYPES) {
				List<Object> row = new ArrayList<>(COLUMN_NAMES.length);
				row.add(currGameType);
				for (int i = 1; i < COLUMN_NAMES.length; i++) {
					row.add(0);
				}
				data.put(currGameType, row);
			}
		}

		@Override
		public int getRowCount() {
			return GAME_TYPES.length;
		}

		@Override
		public int getColumnCount() {
			return COLUMN_NAMES.length;
		}

		@Override
		public String getColumnName(final int column) {
			return COLUMN_NAMES[column];
		}

		@Override
		public Object getValueAt(final int rowIndex, final int columnIndex) {
			return data.get(GAME_TYPES[rowIndex]).get(columnIndex);
		}

		/**
		 * Adds (or replaces) the training result for a game type and notifies
		 * listeners that the corresponding row changed.
		 *
		 * @param gameType
		 *            Game type of neural net, must not be {@code null}
		 * @param episodes
		 *            Number of episodes
		 * @param totalWonGames
		 *            Total number of won games
		 * @param avgNetworkErrorDeclarer
		 *            Average error of declarer network
		 * @param avgNetworkErrorOpponents
		 *            Average error of opponents networks
		 * @throws NullPointerException
		 *             if {@code gameType} is {@code null}
		 */
		public void addTrainingResult(final GameType gameType,
				final Long episodes, final Long totalWonGames,
				final Double avgNetworkErrorDeclarer,
				final Double avgNetworkErrorOpponents) {
			Objects.requireNonNull(gameType, "gameType");

			List<Object> row = data.get(gameType);
			row.set(COLUMN_GAME_TYPE, gameType);
			row.set(COLUMN_EPISODES, episodes);
			row.set(COLUMN_WON_GAMES, totalWonGames);
			row.set(COLUMN_PERCENT,
					calculateWinPercentage(episodes, totalWonGames));
			row.set(COLUMN_ERROR_DECLARER, avgNetworkErrorDeclarer);
			row.set(COLUMN_ERROR_OPPONENTS, avgNetworkErrorOpponents);

			// rows are ordered like GameType.values(), so the ordinal is
			// the row index
			int rowIndex = gameType.ordinal();
			fireTableRowsUpdated(rowIndex, rowIndex);
		}

		/**
		 * Calculates the percentage of won games.
		 *
		 * @param episodes
		 *            Number of episodes, may be {@code null}
		 * @param totalWonGames
		 *            Number of won games, may be {@code null}
		 * @return {@code totalWonGames * 100 / episodes}, or {@code 0.0} when
		 *         either value is {@code null} or {@code episodes} is not
		 *         positive (avoids NaN/Infinity and unboxing NPEs)
		 */
		private static double calculateWinPercentage(final Long episodes,
				final Long totalWonGames) {
			if (episodes == null || totalWonGames == null
					|| episodes.longValue() <= 0L) {
				return 0.0d;
			}
			return totalWonGames.longValue() * 100.0d / episodes.longValue();
		}

		@Override
		public void setValueAt(final Object value, final int rowIndex,
				final int columnIndex) {
			data.get(GAME_TYPES[rowIndex]).set(columnIndex, value);
			// AbstractTableModel contract: listeners must be told about edits
			fireTableCellUpdated(rowIndex, columnIndex);
		}
	}

	/**
	 * Positions the dialog relative to its parent component before it is
	 * shown.
	 *
	 * @see JDialog#setVisible(boolean)
	 */
	@Override
	public void setVisible(final boolean isVisible) {
		if (isVisible) {
			setLocationRelativeTo(parent);
		}
		super.setVisible(isVisible);
	}
}