/**
 * Copyright (C) 2001-2017 by RapidMiner and the contributors
 *
 * Complete list of developers available at our web site:
 *
 * http://rapidminer.com
 *
 * This program is free software: you can redistribute it and/or modify it under the terms of the
 * GNU Affero General Public License as published by the Free Software Foundation, either version 3
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
 * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License along with this program.
 * If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.gui.viewer;

import com.rapidminer.tools.Tools;

import javax.swing.table.AbstractTableModel;


/**
 * The model for the {@link com.rapidminer.gui.viewer.ConfusionMatrixViewerTable}.
 *
 * <p>
 * For {@code n} class names the model exposes an {@code (n + 2) x (n + 2)} grid:
 * <ul>
 * <li>row 0 holds the column captions ({@code "true <class>"}) and, in the last column, the
 * caption {@code "class precision"};</li>
 * <li>column 0 holds the row captions ({@code "pred. <class>"}) and, in the last row, the caption
 * {@code "class recall"};</li>
 * <li>the last column holds the formatted precision of each row, the last row the formatted recall
 * of each column;</li>
 * <li>every remaining cell at table position {@code (row, col)} shows
 * {@code counter[col - 1][row - 1]}.</li>
 * </ul>
 *
 * @author Ingo Mierswa
 */
public class ConfusionMatrixViewerTableModel extends AbstractTableModel {

    private static final long serialVersionUID = 1206988933244249851L;

    /** Class labels, used for both the column and the row captions. */
    private String[] classNames;

    /** Raw counts, addressed as {@code counter[tableColumn - 1][tableRow - 1]}. */
    private double[][] counter;

    /** Per displayed row: sum of {@code counter[j][i]} over all {@code j}. */
    private double[] rowSums;

    /** Per displayed column: sum of {@code counter[i][j]} over all {@code j}. */
    private double[] columnSums;

    /**
     * Creates the model and pre-computes the row and column totals that the precision and recall
     * cells are derived from.
     *
     * @param classNames
     *            the class labels; its length determines the size of the table
     * @param counter
     *            the counts, read as a square matrix with one row and one column per class name;
     *            the array is stored as given, not copied
     */
    public ConfusionMatrixViewerTableModel(String[] classNames, double[][] counter) {
        this.classNames = classNames;
        this.counter = counter;

        final int classCount = classNames.length;
        this.rowSums = new double[classCount];
        this.columnSums = new double[classCount];

        for (int cls = 0; cls < classCount; cls++) {
            // Accumulate locally in ascending index order, then publish both totals.
            double columnTotal = 0;
            double rowTotal = 0;
            for (int other = 0; other < classCount; other++) {
                columnTotal += counter[cls][other];
                rowTotal += counter[other][cls];
            }
            this.columnSums[cls] = columnTotal;
            this.rowSums[cls] = rowTotal;
        }
    }

    /** @return the number of classes plus one caption row and one recall row */
    @Override
    public int getRowCount() {
        return classNames.length + 2;
    }

    /** @return the number of classes plus one caption column and one precision column */
    @Override
    public int getColumnCount() {
        return classNames.length + 2;
    }

    /**
     * Returns the text to display at the given table position.
     *
     * @param row
     *            the table row; 0 is the caption row, the last row is the recall row
     * @param col
     *            the table column; 0 is the caption column, the last column is the precision column
     * @return the caption, formatted percentage, formatted count or empty string for that cell
     */
    @Override
    public Object getValueAt(int row, int col) {
        if (row == 0) {
            return captionRowValue(col);
        }
        if (row == getRowCount() - 1) {
            return recallRowValue(col);
        }
        return classRowValue(row, col);
    }

    /** Value of a cell in the top caption row. */
    private Object captionRowValue(int col) {
        if (col == 0) {
            return "";
        }
        if (col == getColumnCount() - 1) {
            return "class precision";
        }
        return "true " + classNames[col - 1];
    }

    /** Value of a cell in the bottom recall row. */
    private Object recallRowValue(int col) {
        if (col == 0) {
            return "class recall";
        }
        if (col == getColumnCount() - 1) {
            return "";
        }
        final int cls = col - 1;
        return percentOrZero(counter[cls][cls] / columnSums[cls]);
    }

    /** Value of a cell in one of the rows between the caption row and the recall row. */
    private Object classRowValue(int row, int col) {
        if (col == 0) {
            if (row - 1 < 0) {
                return "";
            }
            return "pred. " + classNames[row - 1];
        }
        if (col == getColumnCount() - 1) {
            final int cls = row - 1;
            return percentOrZero(counter[cls][cls] / rowSums[cls]);
        }
        if (col - 1 < 0 || row - 1 < 0) {
            return "";
        }
        return Tools.formatIntegerIfPossible(counter[col - 1][row - 1]);
    }

    /** Formats a ratio as a percentage; a NaN ratio (0 / 0) is shown as zero percent. */
    private static Object percentOrZero(double ratio) {
        if (Double.isNaN(ratio)) {
            return Tools.formatPercent(0);
        }
        return Tools.formatPercent(ratio);
    }
}