/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.eval;
import rapaio.data.Var;
import rapaio.printer.Printable;
import rapaio.printer.format.TextTable;
import rapaio.sys.WS;
import java.util.Arrays;
import java.util.stream.IntStream;
import static rapaio.sys.WS.formatFlex;
/**
 * Confusion matrix utility.
 * <p>
 * Cross-tabulates the level indexes of an {@code actual} nominal variable against
 * those of a {@code predict} nominal variable and derives summary statistics from
 * the resulting table. Rows where either variable has index {@code 0} are skipped
 * and are not counted as complete cases (index 0 is presumably the missing-value
 * level; the corresponding label {@code factors[0]} is never printed).
 * <p>
 * When the variables have exactly three levels (i.e. two printed labels), the
 * binary metrics F1, MCC, precision, recall and G-measure are also computed, with
 * the first printed level treated as the positive class.
 * <p>
 * User: <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a>
 */
public class Confusion implements Printable {

    /** Text placed in the top-left corner cell of the printed table. */
    private static final String CORNER = "Ac\\Pr";

    private final Var actual;
    private final Var predict;
    /** Levels of {@code actual}; entry 0 is never used as a row/column label. */
    private final String[] factors;
    /** Counts indexed as {@code [actual index - 1][predict index - 1]}. */
    private final int[][] cmf;
    private final boolean binary;
    private final boolean percents;

    private double acc;
    private double mcc;
    private double f1;
    private double g;
    private double precision;
    private double recall;

    /** Rows where both actual and predicted index are non-zero. */
    private int completeCases;
    /** Complete cases on the diagonal of {@link #cmf}. */
    private int acceptedCases;
    /** Complete cases off the diagonal of {@link #cmf}. */
    private int errorCases;

    /**
     * Builds a confusion matrix whose summary prints only the table of counts.
     *
     * @param actual  nominal variable holding the true labels
     * @param predict nominal variable holding the predicted labels
     * @throws IllegalArgumentException if the variables fail validation
     */
    public Confusion(Var actual, Var predict) {
        this(actual, predict, false);
    }

    /**
     * Builds a confusion matrix.
     *
     * @param actual   nominal variable holding the true labels
     * @param predict  nominal variable holding the predicted labels
     * @param percents if {@code true}, the summary also prints a second table with
     *                 each count divided by the number of complete cases
     * @throws IllegalArgumentException if either variable is not nominal, if their
     *                                  levels differ, or if {@code predict} has fewer
     *                                  rows than {@code actual}
     */
    public Confusion(Var actual, Var predict, boolean percents) {
        validate(actual, predict);
        this.actual = actual;
        this.predict = predict;
        this.factors = actual.levels();
        this.cmf = new int[factors.length - 1][factors.length - 1];
        this.percents = percents;
        this.binary = actual.levels().length == 3;
        compute();
    }

    /**
     * Checks that both variables are nominal, share the same levels in the same
     * order, and that {@code predict} covers every row of {@code actual}.
     */
    private void validate(Var actual, Var predict) {
        if (!actual.type().isNominal()) {
            throw new IllegalArgumentException("actual values var must be nominal");
        }
        if (!predict.type().isNominal()) {
            throw new IllegalArgumentException("fit values var must be nominal");
        }
        // compute() reads predict.index(i) for every row of actual, so a shorter
        // predict would be read past its last row.
        if (predict.rowCount() < actual.rowCount()) {
            throw new IllegalArgumentException(
                    String.format("fit var has fewer rows than actual var (actual:%d, fit:%d)",
                            actual.rowCount(), predict.rowCount()));
        }
        if (actual.levels().length != predict.levels().length) {
            throw new IllegalArgumentException("actual and fit does not have the same nominal levels");
        }
        for (int i = 0; i < actual.levels().length; i++) {
            if (!actual.levels()[i].equals(predict.levels()[i])) {
                throw new IllegalArgumentException(
                        String.format("not the same nominal levels (actual:%s, fit:%s)",
                                Arrays.deepToString(actual.levels()),
                                Arrays.deepToString(predict.levels())));
            }
        }
    }

    /**
     * Fills the count matrix and derives accuracy and, for the binary case, the
     * additional metrics.
     */
    private void compute() {
        for (int i = 0; i < actual.rowCount(); i++) {
            int actualIndex = actual.index(i);
            int predictIndex = predict.index(i);
            if (actualIndex != 0 && predictIndex != 0) {
                completeCases++;
                cmf[actualIndex - 1][predictIndex - 1]++;
            }
        }

        acceptedCases = IntStream.range(0, cmf.length).map(i -> cmf[i][i]).sum();
        errorCases = completeCases - acceptedCases;
        acc = (completeCases == 0) ? 0 : (double) acceptedCases / completeCases;

        if (binary) {
            // Rows are actual, columns are predicted; the first level is the positive class.
            double tp = cmf[0][0];
            double tn = cmf[1][1];
            double fp = cmf[1][0];
            double fn = cmf[0][1];
            // No guard on zero denominators: these evaluate to NaN in that case.
            mcc = (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
            f1 = 2 * tp / (2 * tp + fp + fn);
            precision = tp / (tp + fp);
            recall = tp / (tp + fn);
            g = Math.sqrt(precision * recall);
        }
    }

    /**
     * Renders the confusion table(s) followed by the computed statistics.
     *
     * @return the textual summary
     */
    @Override
    public String summary() {
        StringBuilder sb = new StringBuilder();
        addConfusionMatrix(sb);
        addDetails(sb);
        return sb.toString();
    }

    /** Appends the complete-case count, accuracy and, if binary, the binary metrics. */
    private void addDetails(StringBuilder sb) {
        sb.append(String.format("\nComplete cases %d from %d\n", completeCases, actual.rowCount()));
        sb.append(String.format("Acc: %s (Accuracy )\n", formatFlex(acc)));
        if (binary) {
            sb.append(String.format("F1: %s (F1 score / F-measure)\n", formatFlex(f1)));
            sb.append(String.format("MCC: %s (Matthew correlation coefficient)\n", formatFlex(mcc)));
            sb.append(String.format("Pre: %s (Precision)\n", formatFlex(precision)));
            sb.append(String.format("Rec: %s (Recall)\n", formatFlex(recall)));
            sb.append(String.format("G: %s (G-measure)\n", formatFlex(g)));
        }
    }

    /**
     * Appends the table of counts and, when percentages were requested and there is
     * at least one complete case, a second table of proportions.
     */
    private void addConfusionMatrix(StringBuilder sb) {
        sb.append("> Confusion\n");
        sb.append("\n");

        int size = factors.length - 1;
        int[] rowTotals = new int[size];
        int[] colTotals = new int[size];
        int grandTotal = 0;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                grandTotal += cmf[i][j];
                rowTotals[i] += cmf[i][j];
                colTotals[j] += cmf[i][j];
            }
        }

        TextTable counts = newFrame();
        fillTable(counts, rowTotals, colTotals, grandTotal, false);
        sb.append(counts.summary());

        if (percents && completeCases > 0) {
            TextTable ratios = newFrame();
            fillTable(ratios, rowTotals, colTotals, grandTotal, true);
            sb.append(ratios.summary());
        }
    }

    /**
     * Creates an empty table holding only the fixed frame: corner text, level
     * labels, dashed rules, vertical bars and the "total" headers.
     */
    private TextTable newFrame() {
        final int n = factors.length;
        final String cornerRule = line(CORNER.length());

        TextTable tt = TextTable.newEmpty(n + 3, n + 3);
        tt.withSplit();
        tt.set(0, 0, CORNER, 0);
        for (int i = 0; i < n - 1; i++) {
            String label = factors[i + 1];
            String rule = line(label.length());
            tt.set(i + 2, 0, label, 1);
            tt.set(i + 2, 1, "|", 0);
            tt.set(i + 2, n + 1, "|", 0);
            tt.set(0, i + 2, label, 1);
            tt.set(1, i + 2, rule, 1);
            tt.set(n + 1, i + 2, rule, 1);
        }
        tt.set(n + 2, 0, "total", 1);
        tt.set(0, n + 2, "total", 1);
        tt.set(1, 0, cornerRule, 0);
        tt.set(n + 1, 0, cornerRule, 0);
        tt.set(1, n + 2, cornerRule, 0);
        tt.set(n + 1, n + 2, cornerRule, 0);
        tt.set(0, 1, "|", 0);
        tt.set(1, 1, "|", 0);
        tt.set(n + 1, 1, "|", 0);
        tt.set(n + 2, 1, "|", 0);
        tt.set(0, n + 1, "|", 0);
        tt.set(1, n + 1, "|", 0);
        tt.set(n + 1, n + 1, "|", 0);
        tt.set(n + 2, n + 1, "|", 0);
        return tt;
    }

    /**
     * Writes the matrix cells and the row, column and grand totals into a framed
     * table. Diagonal cells are prefixed with {@code ">"}, the others with a space.
     *
     * @param asRatio if {@code true}, values are divided by the complete-case count
     *                and formatted with {@code WS.formatShort}
     */
    private void fillTable(TextTable tt, int[] rowTotals, int[] colTotals, int grandTotal, boolean asRatio) {
        final int n = factors.length;
        for (int i = 0; i < n - 1; i++) {
            for (int j = 0; j < n - 1; j++) {
                tt.set(i + 2, j + 2, ((i == j) ? ">" : " ") + cell(cmf[i][j], asRatio), 1);
            }
        }
        for (int i = 0; i < n - 1; i++) {
            tt.set(n + 2, i + 2, cell(colTotals[i], asRatio), 1);
            tt.set(i + 2, n + 2, cell(rowTotals[i], asRatio), 1);
        }
        tt.set(n + 2, n + 2, cell(grandTotal, asRatio), 1);
    }

    /** Formats a count either verbatim or as a proportion of the complete cases. */
    private String cell(int count, boolean asRatio) {
        // The explicit cast keeps this a floating-point division.
        return asRatio ? WS.formatShort(count / (double) completeCases) : String.valueOf(count);
    }

    /** Returns a string of {@code len} dash characters. */
    private String line(int len) {
        char[] dashes = new char[len];
        Arrays.fill(dashes, '-');
        return String.valueOf(dashes);
    }

    /**
     * @return proportion of complete cases on the diagonal, or 0 when there are no
     * complete cases
     */
    public double accuracy() {
        return acc;
    }

    /**
     * @return {@code 1.0 - accuracy()}
     */
    public double error() {
        return 1.0 - acc;
    }

    /**
     * Number of cases which were correctly predicted
     */
    public int acceptedCases() {
        return acceptedCases;
    }

    /**
     * Number of cases which were not predicted correctly
     */
    public int errorCases() {
        return errorCases;
    }

    /**
     * Number of rows where both the actual and the predicted index are non-zero
     */
    public int completeCases() {
        return completeCases;
    }

    /**
     * Returns a copy of the count matrix, indexed as
     * {@code [actual index - 1][predict index - 1]}. Modifying the returned array
     * does not affect this object.
     */
    public int[][] matrix() {
        int[][] copy = new int[cmf.length][];
        for (int i = 0; i < cmf.length; i++) {
            copy[i] = cmf[i].clone();
        }
        return copy;
    }
}