/*
* This file is part of ELKI:
* Environment for Developing KDD-Applications Supported by Index-Structures
*
* Copyright (C) 2017
* ELKI Development Team
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package de.lmu.ifi.dbs.elki.evaluation.classification;
import java.text.NumberFormat;
import java.util.ArrayList;
import de.lmu.ifi.dbs.elki.data.ClassLabel;
/**
 * Provides a confusion matrix with some prediction performance measures that
 * can be derived from a confusion matrix.
 * <p>
 * The matrix and the label list handed to the constructor are stored by
 * reference (they are not copied).
 *
 * @author Arthur Zimek
 * @since 0.7.0
 */
public class ConfusionMatrix {
  /**
   * Prefix used for the short class names printed by {@link #toString()}.
   */
  private static final String CLASS_PREFIX = "C_";

  /**
   * Holds the confusion matrix. Must be a square matrix. The rows (first index)
   * give the values which classes are classified as row, the columns (second
   * index) give the values the class col has been classified as. Thus,
   * <code>confusion[predicted][real]</code> addresses the number of instances
   * of class <i>real</i> that have been assigned to the class <i>predicted</i>.
   */
  private final int[][] confusion;

  /**
   * Holds the class labels.
   */
  private final ArrayList<ClassLabel> labels;

  /**
   * Provides a confusion matrix for the given values.
   *
   * @param labels the class labels - must conform the confusion matrix in
   *        length
   * @param confusion the confusion matrix. Must be a square matrix. The rows
   *        (first index) give the values which classes are classified as row,
   *        the columns (second index) give the values the class col has been
   *        classified as. Thus, <code>confusion[predicted][real]</code>
   *        addresses the number of instances of class <i>real</i> that have
   *        been assigned to the class <i>predicted</i>.
   * @throws IllegalArgumentException if the confusion matrix is not square or
   *         not complete or if the length of class labels does not conform the
   *         length of the confusion matrix
   */
  public ConfusionMatrix(ArrayList<ClassLabel> labels, int[][] confusion) throws IllegalArgumentException {
    for(int i = 0; i < confusion.length; i++) {
      // A missing row means the matrix is "not complete"; report it with the
      // documented exception type instead of failing with a
      // NullPointerException on the length access below.
      if(confusion[i] == null) {
        throw new IllegalArgumentException("Confusion matrix incomplete: row " + i + " is missing.");
      }
      if(confusion.length != confusion[i].length) {
        throw new IllegalArgumentException("Confusion matrix irregular: row-dimension = " + confusion.length + ", col-dimension in row " + i + " = " + confusion[i].length);
      }
    }
    if(confusion.length != labels.size()) {
      throw new IllegalArgumentException("Number of class labels does not match row dimension of confusion matrix.");
    }
    this.confusion = confusion;
    this.labels = labels;
  }

  /**
   * Provides the overall <i>true positive rate</i>, computed as the number of
   * correctly classified instances divided by the total number of instances
   * (i.e., the <i>accuracy</i>; summed over all classes this equals
   * <code>TP / (TP+FN)</code>).
   * <p>
   * The result is <code>NaN</code> if the matrix contains no instances.
   *
   * @return the true positive rate
   */
  public double truePositiveRate() {
    return ((double) truePositives()) / (double) totalInstances();
  }

  /**
   * Provides the <i>false positive rate</i> for the specified class:
   * <code>FP / (FP+TN)</code>.
   * <p>
   * The result is <code>NaN</code> if <code>FP+TN</code> is zero.
   *
   * @param classindex the index of the class to retrieve the false positive
   *        rate for
   * @return the false positive rate for the specified class
   */
  public double falsePositiveRate(int classindex) {
    int fp = falsePositives(classindex);
    int tn = trueNegatives(classindex);
    return ((double) fp) / ((double) (fp + tn));
  }

  /**
   * Provides the <i>false positive rate</i>. Aka <i>false alarm rate</i>:
   * <code>FP / (FP+TN)</code>.
   * <p>
   * Computed as the average of the per-class false positive rates, each
   * weighted by the number of real instances of that class (the column sum).
   *
   * @return the false positive rate
   */
  public double falsePositiveRate() {
    double fpr = 0;
    for(int i = 0; i < confusion.length; i++) {
      final int weight = colSum(i);
      // Classes without real instances carry no weight; skip them so that an
      // undefined per-class rate cannot turn the whole average into NaN.
      if(weight != 0) {
        fpr += falsePositiveRate(i) * weight;
      }
    }
    return fpr / totalInstances();
  }

  /**
   * Provides the <i>positive predicted value</i> for the specified class:
   * <code>TP / (TP+FP)</code>.
   * <p>
   * The result is <code>NaN</code> if no instance was assigned to the class
   * (<code>TP+FP</code> is zero).
   *
   * @param classindex the index of the class to retrieve the positive predicted
   *        value for
   * @return the positive predicted value for the specified class
   */
  public double positivePredictedValue(int classindex) {
    int tp = truePositives(classindex);
    return (double) tp / ((double) (tp + falsePositives(classindex)));
  }

  /**
   * Provides the <i>positive predicted value</i>. Aka <i>precision</i>:
   * <code>TP / (TP+FP)</code>.
   * <p>
   * Computed as the average of the per-class positive predicted values, each
   * weighted by the number of real instances of that class (the column sum).
   * Classes without real instances have weight zero and are ignored. The
   * result is still <code>NaN</code> if a class with real instances was never
   * predicted, or if the matrix contains no instances.
   *
   * @return the positive predicted value
   */
  public double positivePredictedValue() {
    double ppv = 0;
    for(int i = 0; i < confusion.length; i++) {
      final int weight = colSum(i);
      // A class that neither occurs nor is predicted has an undefined (NaN)
      // per-class value; NaN * 0 would otherwise poison the weighted sum.
      if(weight != 0) {
        ppv += positivePredictedValue(i) * weight;
      }
    }
    return ppv / totalInstances();
  }

  /**
   * The number of correctly classified instances (the sum of the diagonal).
   *
   * @return the number of correctly classified instances
   */
  public int truePositives() {
    int tp = 0;
    for(int i = 0; i < confusion.length; i++) {
      tp += truePositives(i);
    }
    return tp;
  }

  /**
   * The number of correctly classified instances belonging to the specified
   * class.
   *
   * @param classindex the index of the class to retrieve the correctly
   *        classified instances of
   * @return the number of correctly classified instances belonging to the
   *         specified class
   */
  public int truePositives(int classindex) {
    return confusion[classindex][classindex];
  }

  /**
   * Provides the <i>true positive rate</i> for the specified class:
   * <code>TP / (TP+FN)</code>.
   * <p>
   * The result is <code>NaN</code> if the class has no real instances
   * (<code>TP+FN</code> is zero).
   *
   * @param classindex the index of the class to retrieve the true positive rate
   *        for
   * @return the true positive rate
   */
  public double truePositiveRate(int classindex) {
    int tp = truePositives(classindex);
    return (double) tp / ((double) (tp + falseNegatives(classindex)));
  }

  /**
   * The number of true negatives of the specified class, i.e., all entries
   * that lie neither in the row nor in the column of the class.
   *
   * @param classindex the index of the class to retrieve the true negatives for
   * @return the number of true negatives of the specified class
   */
  public int trueNegatives(int classindex) {
    int tn = 0;
    for(int i = 0; i < confusion.length; i++) {
      if(i == classindex) {
        continue; // row of the class itself: predicted as classindex
      }
      for(int j = 0; j < confusion[i].length; j++) {
        if(j != classindex) {
          tn += confusion[i][j];
        }
      }
    }
    return tn;
  }

  /**
   * The false positives for the specified class, i.e., the off-diagonal
   * entries of the row of the class.
   *
   * @param classindex the index of the class to retrieve the false positives
   *        for
   * @return the false positives for the specified class
   */
  public int falsePositives(int classindex) {
    int fp = 0;
    for(int i = 0; i < confusion[classindex].length; i++) {
      if(i != classindex) {
        fp += confusion[classindex][i];
      }
    }
    return fp;
  }

  /**
   * The false negatives for the specified class, i.e., the off-diagonal
   * entries of the column of the class.
   *
   * @param classindex the index of the class to retrieve the false negatives
   *        for
   * @return the false negatives for the specified class
   */
  public int falseNegatives(int classindex) {
    int fn = 0;
    for(int i = 0; i < confusion.length; i++) {
      if(i != classindex) {
        fn += confusion[i][classindex];
      }
    }
    return fn;
  }

  /**
   * The total number of instances covered by this confusion matrix.
   *
   * @return the total number of instances covered by this confusion matrix
   */
  public int totalInstances() {
    int total = 0;
    for(int[] row : confusion) {
      for(int count : row) {
        total += count;
      }
    }
    return total;
  }

  /**
   * The number of instances present in the specified row. I.e., classified as
   * class <code>classindex</code>.
   *
   * @param classindex the index of the class the resulting number of instances
   *        has been classified as
   * @return the number of instances present in the specified row
   */
  public int rowSum(int classindex) {
    int s = 0;
    for(int count : confusion[classindex]) {
      s += count;
    }
    return s;
  }

  /**
   * The number of instances present in the specified column. I.e., the
   * instances of class <code>classindex</code>.
   *
   * @param classindex the index of the class the resulting number of instances
   *        belongs to
   * @return the number of instances present in the specified column
   */
  public int colSum(int classindex) {
    int s = 0;
    for(int i = 0; i < confusion.length; i++) {
      s += confusion[i][classindex];
    }
    return s;
  }

  /**
   * The number of instances belonging to class <code>trueClassindex</code> and
   * predicted as <code>predictedClassindex</code>.
   * <p>
   * Note the argument order: the matrix itself is indexed as
   * <code>confusion[predicted][real]</code>.
   *
   * @param trueClassindex the true class index
   * @param predictedClassindex the predicted class index
   * @return the number of instances belonging to class
   *         <code>trueClassindex</code> and predicted as
   *         <code>predictedClassindex</code>
   */
  public int value(int trueClassindex, int predictedClassindex) {
    return confusion[predictedClassindex][trueClassindex];
  }

  /**
   * Provides a String representation of this confusion matrix.
   * <p>
   * The first line is a header with a short name (<code>C_</code> followed by
   * the zero-padded, 1-based class number) per column. Each following line
   * holds one matrix row, right-aligned to the header, followed by the short
   * name of the row's class and its label.
   *
   * @see java.lang.Object#toString()
   */
  @Override
  public String toString() {
    // Largest entry determines the minimum width of a cell.
    int max = 0;
    for(int[] row : confusion) {
      for(int count : row) {
        if(count > max) {
          max = count;
        }
      }
    }
    NumberFormat nf = NumberFormat.getInstance();
    // Grouping is enabled by default and would insert separators into the
    // zero-padded class numbers (e.g. "0,001"), breaking the column widths.
    nf.setGroupingUsed(false);
    int labelLength = Integer.toString(labels.size()).length();
    nf.setMaximumIntegerDigits(labelLength);
    nf.setMinimumIntegerDigits(labelLength);
    int cell = Math.max(Integer.toString(max).length(), labelLength + CLASS_PREFIX.length());
    String separator = " ";
    StringBuilder representation = new StringBuilder();
    // Header line: one short class name per column.
    for(int i = 1; i <= labels.size(); i++) {
      appendCell(representation, separator, CLASS_PREFIX + nf.format(i), cell);
    }
    representation.append('\n');
    // One line per predicted class, followed by the legend for that class.
    for(int row = 0; row < confusion.length; row++) {
      for(int col = 0; col < confusion[row].length; col++) {
        appendCell(representation, separator, Integer.toString(confusion[row][col]), cell);
      }
      representation.append(separator);
      representation.append(CLASS_PREFIX);
      representation.append(nf.format(row + 1));
      representation.append(": ");
      representation.append(labels.get(row));
      representation.append('\n');
    }
    return representation.toString();
  }

  /**
   * Appends the separator followed by the text, right-aligned in a column of
   * <code>cell + 1</code> characters.
   *
   * @param buf the buffer to append to
   * @param separator the column separator written before the cell
   * @param text the cell content
   * @param cell the cell width (without the extra leading blank)
   */
  private static void appendCell(StringBuilder buf, String separator, String text, int cell) {
    buf.append(separator);
    // One blank more than strictly needed, so adjacent cells never touch.
    for(int pad = cell - text.length(); pad >= 0; pad--) {
      buf.append(' ');
    }
    buf.append(text);
  }
}