/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.eval;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.meta.Prediction;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.Not;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.*;
/**
* Evaluation metrics:
* precision, recall, f1
*
* @author Adam Gibson
*/
@Slf4j
@EqualsAndHashCode(callSuper = true)
public class Evaluation extends BaseEvaluation<Evaluation> {
// Top-N setting used by topNAccuracy(): an example counts as correct when no more than (topN - 1) other
// classes have a higher predicted probability than the true class. Values <= 1 mean standard accuracy.
protected final int topN;
// Number of examples counted as correct for top-N accuracy; only updated by eval(INDArray, INDArray, List)
// when topN > 1 and the labels have more than one column
protected int topNCorrectCount = 0;
protected int topNTotalCount = 0; //Could use topNCountCorrect / (double)getNumRowCounter() - except for eval(int,int), hence separate counters
// Per-class confusion counts, keyed by class index
protected Counter<Integer> truePositives = new Counter<>();
protected Counter<Integer> falsePositives = new Counter<>();
protected Counter<Integer> trueNegatives = new Counter<>();
protected Counter<Integer> falseNegatives = new Counter<>();
// Confusion matrix over class indices; null until created by a labelled constructor or by the first batch eval
protected ConfusionMatrix<Integer> confusion;
// Total number of examples (rows) evaluated so far
protected int numRowCounter = 0;
// Human-readable class names, indexed by class; used by stats() to print the confusion matrix
@Getter
@Setter
protected List<String> labelsList = new ArrayList<>();
//What to output from the precision/recall function when we encounter an edge case
protected static final double DEFAULT_EDGE_VALUE = 0.0;
// Record metadata grouped by confusion-matrix cell. Presumably populated by addToMetaConfusionMatrix
// (defined outside this view) - TODO confirm
protected Map<Pair<Integer, Integer>, List<Object>> confusionMatrixMetaData; //Pair: (Actual,Predicted)
/**
 * Creates an evaluation without predefined labels, using standard (top-1) accuracy.
 * The number of classes is inferred from the first batch passed to {@code eval(INDArray, INDArray)}.
 */
public Evaluation() {
    topN = 1;
}
/**
 * Creates an evaluation for a fixed number of output classes, labelled by their indices ("0", "1", ...).
 * A value of 1 is treated as the binary (single output column) case, i.e. two classes.
 *
 * @param numClasses the number of classes to account for in the evaluation
 */
public Evaluation(int numClasses) {
    this(createLabels(numClasses), 1);
}
/**
 * Creates an evaluation whose classes are named by the given labels, so that output such as
 * {@link #stats()} shows the label names instead of bare class indices. Uses standard (top-1) accuracy.
 *
 * @param labels the label names, where index i names class i
 */
public Evaluation(List<String> labels) {
    this(labels, 1);
}
/**
 * Creates an evaluation from a map of class index to label name.
 * The map must contain every key from 0 to {@code labels.size() - 1}.
 *
 * @param labels a map of label index to label value
 */
public Evaluation(Map<Integer, String> labels) {
    this(createLabelsFromMap(labels), 1);
}
/**
 * Creates an evaluation that also tracks top-N accuracy.
 *
 * @param labels Labels for the classes (may be null). When non-null, the confusion matrix is created
 *               immediately with {@code labels.size()} classes.
 * @param topN   Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N
 *               accuracy, an example is considered 'correct' if the probability for the true class is one of the
 *               highest N values
 */
public Evaluation(List<String> labels, int topN) {
    this.topN = topN;
    this.labelsList = labels;
    if (labels == null) {
        return;
    }
    createConfusion(labels.size());
}
/**
 * Clears all accumulated statistics so this instance can be reused for a new evaluation.
 * <p>
 * The confusion matrix is discarded. It is recreated on the next call to {@code eval(INDArray, INDArray)}.
 * Until then, {@link #eval(int, int)} throws an exception. Recorded per-example metadata is discarded too.
 * {@code topN} and the label list are not modified by this method.
 */
@Override
public void reset() {
    confusion = null;
    truePositives = new Counter<>();
    falsePositives = new Counter<>();
    trueNegatives = new Counter<>();
    falseNegatives = new Counter<>();
    topNCorrectCount = 0;
    topNTotalCount = 0;
    numRowCounter = 0;
    // Return the metadata map to its initial (null) state; otherwise records from before the reset
    // would still be reported alongside the new evaluation's results
    confusionMatrixMetaData = null;
}
/**
 * Returns the confusion matrix. If none exists yet, an empty one (with no predefined classes)
 * is created and stored first.
 */
private ConfusionMatrix<Integer> confusion() {
    if (confusion == null) {
        confusion = new ConfusionMatrix<>();
    }
    return confusion;
}
/**
 * Builds default labels "0", "1", ..., "n-1".
 * A single class is treated as the binary case and yields two labels.
 */
private static List<String> createLabels(int numClasses) {
    final int n = (numClasses == 1) ? 2 : numClasses; //Binary (single output variable) case...
    final List<String> result = new ArrayList<>(n);
    for (int idx = 0; idx < n; idx++) {
        result.add(Integer.toString(idx));
    }
    return result;
}
/**
 * Converts a map of class index to label name into a list ordered by class index.
 *
 * @throws IllegalArgumentException if any index from 0 to {@code labels.size() - 1} has no label
 */
private static List<String> createLabelsFromMap(Map<Integer, String> labels) {
    final int count = labels.size();
    final List<String> ordered = new ArrayList<>(count);
    for (int classIdx = 0; classIdx < count; classIdx++) {
        final String name = labels.get(classIdx);
        if (name == null) {
            throw new IllegalArgumentException("Invalid labels map: missing key for class " + classIdx
                    + " (expect integers 0 to " + (count - 1) + ")");
        }
        ordered.add(name);
    }
    return ordered;
}
/** Replaces the confusion matrix with a new one whose classes are the indices 0 .. nClasses-1. */
private void createConfusion(int nClasses) {
    final List<Integer> classIndices = new ArrayList<>();
    int next = 0;
    while (next < nClasses) {
        classIndices.add(next);
        next++;
    }
    confusion = new ConfusionMatrix<>(classIndices);
}
/**
 * Evaluates a {@link ComputationGraph} on the given input.
 * It calls {@code network.output(false, input)} and compares the graph's first output array
 * against the true labels.
 *
 * @param trueLabels the labels to use
 * @param input      the input to the network to use for evaluation
 * @param network    the computation graph to use for output
 */
public void eval(INDArray trueLabels, INDArray input, ComputationGraph network) {
    INDArray[] outputs = network.output(false, input);
    eval(trueLabels, outputs[0]);
}
/**
 * Evaluates a {@link MultiLayerNetwork} on the given input.
 * It computes the network's output in {@code Layer.TrainingMode.TEST} and compares it against the true labels.
 *
 * @param trueLabels the labels to use
 * @param input      the input to the network to use for evaluation
 * @param network    the network to use for output
 */
public void eval(INDArray trueLabels, INDArray input, MultiLayerNetwork network) {
    INDArray predictions = network.output(input, Layer.TrainingMode.TEST);
    eval(trueLabels, predictions);
}
/**
 * Collects statistics on the real outcomes vs the guesses, without any record metadata.
 * <p>
 * An IllegalArgumentException is thrown if the two matrices do not have the same length.
 *
 * @param realOutcomes the real outcomes (labels - usually binary)
 * @param guesses      the guesses/prediction (usually a probability vector)
 */
public void eval(INDArray realOutcomes, INDArray guesses) {
    final List<Serializable> noMetaData = null;
    eval(realOutcomes, guesses, noMetaData);
}
/**
 * Evaluate the network, with optional metadata.
 * <p>
 * If no confusion matrix exists yet, the number of classes is inferred from the number of columns of
 * {@code realOutcomes}. This happens when no labels were supplied at construction or {@link #reset()} was called.
 * A single column is treated as the binary case with 2 classes. An existing label list whose size matches
 * is kept; otherwise labels "0", "1", ... are generated.
 * <p>
 * Single-column outcomes are evaluated as binary classification, where a guess counts as class 1 when it
 * is greater than 0.5. With more than one column, the actual and predicted classes are the per-row argmax.
 * When {@code topN > 1} and there is more than one column, the top-N accuracy counters are also updated.
 *
 * @param realOutcomes   Data labels, one row per example
 * @param guesses        Network predictions; must have the same length as {@code realOutcomes}
 * @param recordMetaData Optional; may be null. If not null, should have size equal to the number of
 *                       outcomes/guesses. Examples without a corresponding entry are counted but get no metadata.
 * @throws IllegalArgumentException if {@code realOutcomes} and {@code guesses} differ in length; no
 *                                  statistics are updated in that case
 */
@Override
public void eval(final INDArray realOutcomes, final INDArray guesses, final List<? extends Serializable> recordMetaData) {
    // Length of real labels must be same as length of predicted labels. Checked before any state is
    // modified, so a rejected call does not leave numRowCounter or the confusion matrix partially updated.
    if (realOutcomes.length() != guesses.length())
        throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");

    // Add the number of rows to numRowCounter
    numRowCounter += realOutcomes.shape()[0];

    // If confusion is null, Evaluation was instantiated without providing the classes (or was reset)
    // -> infer # classes from the labels array
    if (confusion == null) {
        int nClasses = realOutcomes.columns();
        if (nClasses == 1)
            nClasses = 2; //Binary (single output variable) case
        // Keep existing labels (e.g. user-supplied labels that survived reset()) when their count matches
        if (labelsList == null || labelsList.size() != nClasses) {
            labelsList = new ArrayList<>(nClasses);
            for (int i = 0; i < nClasses; i++)
                labelsList.add(String.valueOf(i));
        }
        createConfusion(nClasses);
    }

    final int nCols = realOutcomes.columns();
    final int nRows = realOutcomes.rows();

    if (nCols == 1) {
        // Binary case: threshold the single output at 0.5
        INDArray binaryGuesses = guesses.gt(0.5);
        INDArray notLabel = Nd4j.getExecutioner().execAndReturn(new Not(realOutcomes.dup()));
        INDArray notGuess = Nd4j.getExecutioner().execAndReturn(new Not(binaryGuesses.dup()));
        //tp: predicted = 1, actual = 1
        int tp = binaryGuesses.mul(realOutcomes).sumNumber().intValue();
        //fp: predicted = 1, actual = 0
        int fp = binaryGuesses.mul(notLabel).sumNumber().intValue();
        //fn: predicted = 0, actual = 1
        int fn = notGuess.mul(realOutcomes).sumNumber().intValue();
        int tn = nRows - tp - fp - fn;

        confusion().add(1, 1, tp);
        confusion().add(1, 0, fn);
        confusion().add(0, 1, fp);
        confusion().add(0, 0, tn);

        truePositives.incrementCount(1, tp);
        falsePositives.incrementCount(1, fp);
        falseNegatives.incrementCount(1, fn);
        trueNegatives.incrementCount(1, tn);

        // Class 0 sees the mirrored picture: its positives are class 1's negatives and vice versa
        truePositives.incrementCount(0, tn);
        falsePositives.incrementCount(0, fn);
        falseNegatives.incrementCount(0, fp);
        trueNegatives.incrementCount(0, tp);

        if (recordMetaData != null) {
            for (int i = 0; i < binaryGuesses.size(0); i++) {
                if (i >= recordMetaData.size())
                    break;
                // Use example i's own values so that each metadata record lands in its own confusion cell
                int actual = realOutcomes.getDouble(i) == 0.0 ? 0 : 1;
                int predicted = binaryGuesses.getDouble(i) == 0.0 ? 0 : 1;
                addToMetaConfusionMatrix(actual, predicted, recordMetaData.get(i));
            }
        }
    } else {
        // Multi-class case: actual and predicted class are the per-row argmax
        final INDArray guessIndex = Nd4j.argMax(guesses, 1);
        final INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = guessIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int actual = (int) realOutcomeIndex.getDouble(i);
            int predicted = (int) guessIndex.getDouble(i);
            confusion().add(actual, predicted);
            if (recordMetaData != null && recordMetaData.size() > i) {
                Object m = recordMetaData.get(i);
                addToMetaConfusionMatrix(actual, predicted, m);
            }
            // Instead of looping through each label of the confusion matrix, infer the per-class
            // true/false positive/negative updates directly from (actual, predicted)
            if (actual == predicted) {
                // Correct: true positive for that class, true negative for every other class
                truePositives.incrementCount(actual, 1);
                for (int col = 0; col < nCols; col++) {
                    if (col == actual) {
                        continue;
                    }
                    trueNegatives.incrementCount(col, 1);
                }
            } else {
                falsePositives.incrementCount(predicted, 1);
                falseNegatives.incrementCount(actual, 1);
                // Every class other than 'actual' and 'predicted' is a true negative
                int lesserIndex, greaterIndex;
                if (actual < predicted) {
                    lesserIndex = actual;
                    greaterIndex = predicted;
                } else {
                    lesserIndex = predicted;
                    greaterIndex = actual;
                }
                for (int col = 0; col < lesserIndex; col++) {
                    trueNegatives.incrementCount(col, 1); // all cols prior
                }
                for (int col = lesserIndex + 1; col < greaterIndex; col++) {
                    trueNegatives.incrementCount(col, 1); // cols between
                }
                for (int col = greaterIndex + 1; col < nCols; col++) {
                    trueNegatives.incrementCount(col, 1); // all cols after
                }
            }
        }
    }

    if (nCols > 1 && topN > 1) {
        //Calculate top N accuracy
        //TODO: this could be more efficient
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = realOutcomeIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int labelIdx = (int) realOutcomeIndex.getDouble(i);
            double prob = guesses.getDouble(i, labelIdx);
            INDArray row = guesses.getRow(i);
            int countGreaterThan = (int) Nd4j.getExecutioner()
                    .exec(new MatchCondition(row, Conditions.greaterThan(prob)), Integer.MAX_VALUE)
                    .getDouble(0);
            if (countGreaterThan < topN) {
                //For example, for top 3 accuracy: can have at most 2 other probabilities larger
                topNCorrectCount++;
            }
            topNTotalCount++;
        }
    }
}
/**
 * Evaluate a single prediction (one prediction at a time).
 * <p>
 * The confusion matrix must already exist. That is the case when the instance was constructed with labels or
 * a number of classes, or has evaluated a batch via {@code eval(INDArray, INDArray)} since the last reset.
 * This method does not update the top-N accuracy counters.
 *
 * @param predictedIdx Index of class predicted by the network
 * @param actualIdx    Index of actual class
 * @throws UnsupportedOperationException if the confusion matrix has not been initialized; no counters are
 *                                       updated in that case
 */
public void eval(int predictedIdx, int actualIdx) {
    // If confusion is null, then Evaluation is instantiated without providing the classes
    if (confusion == null) {
        throw new UnsupportedOperationException(
                "Cannot evaluate single example without initializing confusion matrix first");
    }
    // Count the example only after validation, so a rejected call leaves numRowCounter unchanged
    numRowCounter++;

    addToConfusion(actualIdx, predictedIdx);

    if (predictedIdx == actualIdx) {
        // Correct prediction: a true positive for this class...
        incrementTruePositives(predictedIdx);
        // ...and a true negative for every other class
        for (Integer clazz : confusion().getClasses()) {
            if (clazz.intValue() != predictedIdx)
                trueNegatives.incrementCount(clazz, 1.0);
        }
    } else {
        // Wrong prediction: the actual class was missed (false negative)...
        incrementFalseNegatives(actualIdx);
        // ...the predicted class was wrongly chosen (false positive)...
        incrementFalsePositives(predictedIdx);
        // ...and every remaining class was correctly not chosen (true negative)
        for (Integer clazz : confusion().getClasses()) {
            if (clazz.intValue() != predictedIdx && clazz.intValue() != actualIdx)
                trueNegatives.incrementCount(clazz, 1.0);
        }
    }
}
/**
 * Returns the classification report as a String, including any warnings.
 * Equivalent to {@code stats(false)}.
 *
 * @return A (multi-line) String with accuracy, precision, recall, f1 score etc
 */
public String stats() {
    final boolean suppressWarnings = false;
    return stats(suppressWarnings);
}
/**
 * Method to obtain the classification report as a String.
 * <p>
 * The report has three parts:
 * <ul>
 *   <li>every non-empty confusion matrix cell;</li>
 *   <li>optional warnings about classes excluded from the precision/recall averages;</li>
 *   <li>a score summary: accuracy, optional top-N accuracy, and macro-averaged precision, recall and F1.</li>
 * </ul>
 *
 * @param suppressWarnings whether or not to output warnings related to the evaluation results
 * @return A (multi-line) String with accuracy, precision, recall, f1 score etc
 */
public String stats(boolean suppressWarnings) {
    final List<Integer> classes = confusion().getClasses();
    final StringBuilder report = new StringBuilder("\n");
    final StringBuilder warnings = new StringBuilder();
    // tp == 0 && fp == 0: precision is 0/0 for these classes
    final List<Integer> precisionUndefined = new ArrayList<>();
    // tp == 0 && fn == 0: recall is 0/0 for these classes
    final List<Integer> recallUndefined = new ArrayList<>();

    for (Integer actualClass : classes) {
        final String actualLabel = resolveLabelForClass(actualClass);
        // Confusion matrix: one line per non-empty (actual, predicted) cell
        for (Integer predictedClass : classes) {
            final int count = confusion().getCount(actualClass, predictedClass);
            if (count == 0) {
                continue;
            }
            report.append(String.format("Examples labeled as %s classified by model as %s: %d times%n",
                    actualLabel, resolveLabelForClass(predictedClass), count));
        }
        if (suppressWarnings || truePositives.getCount(actualClass) != 0) {
            continue;
        }
        if (falsePositives.getCount(actualClass) == 0) {
            precisionUndefined.add(actualClass);
        }
        if (falseNegatives.getCount(actualClass) == 0) {
            recallUndefined.add(actualClass);
        }
    }

    if (!precisionUndefined.isEmpty()) {
        warningHelper(warnings, precisionUndefined, "precision");
    }
    if (!recallUndefined.isEmpty()) {
        warningHelper(warnings, recallUndefined, "recall");
    }
    report.append("\n").append(warnings);

    final int nClasses = confusion().getClasses().size();
    final DecimalFormat df = new DecimalFormat("0.0000");
    final double acc = accuracy();
    final double precisionMacro = precision(EvaluationAveraging.Macro);
    final double recallMacro = recall(EvaluationAveraging.Macro);
    final double f1Macro = f1(EvaluationAveraging.Macro);

    report.append("\n==========================Scores========================================")
            .append("\n # of classes: ").append(nClasses)
            .append("\n Accuracy: ").append(format(df, acc));
    if (topN > 1) {
        report.append("\n Top ").append(topN).append(" Accuracy: ").append(format(df, topNAccuracy()));
    }

    report.append("\n Precision: ").append(format(df, precisionMacro));
    if (nClasses > 2) {
        appendExcludedFromAverageNote(report, averagePrecisionNumClassesExcluded());
    }
    report.append("\n Recall: ").append(format(df, recallMacro));
    if (nClasses > 2) {
        appendExcludedFromAverageNote(report, averageRecallNumClassesExcluded());
    }
    report.append("\n F1 Score: ").append(format(df, f1Macro));
    if (nClasses > 2) {
        appendExcludedFromAverageNote(report, averageF1NumClassesExcluded());
        report.append("\nPrecision, recall & F1: macro-averaged (equally weighted avg. of ").append(nClasses)
                .append(" classes)");
    }
    //Note that we could report micro-averaged too - but these are the same as accuracy
    //"Note that for “micro”-averaging in a multiclass setting with all labels included will produce equal precision, recall and F,"
    //http://scikit-learn.org/stable/modules/model_evaluation.html
    report.append("\n========================================================================");
    return report.toString();
}

/**
 * Appends a note such as "\t(2 classes excluded from average)" to a score line in {@code stats(boolean)}.
 * Nothing is appended when {@code excluded} is not positive.
 */
private static void appendExcludedFromAverageNote(StringBuilder sb, int excluded) {
    if (excluded <= 0) {
        return;
    }
    sb.append("\t(").append(excluded).append(excluded > 1 ? " classes" : " class")
            .append(" excluded from average)");
}
/**
 * Formats a score with the given format.
 * NaN and infinite values are returned in their {@code String.valueOf} form instead.
 */
private static String format(DecimalFormat f, double num) {
    final boolean finite = !Double.isNaN(num) && !Double.isInfinite(num);
    return finite ? f.format(num) : String.valueOf(num);
}
/**
 * Returns the label name for a class index.
 * Falls back to the index itself (as a String) when there is no label list or no entry for that index.
 */
private String resolveLabelForClass(Integer clazz) {
    final boolean hasLabel = labelsList != null && clazz < labelsList.size();
    return hasLabel ? labelsList.get(clazz) : clazz.toString();
}
/**
 * Appends a warning listing the classes excluded from the average of the given metric because the metric
 * is 0/0 for them.
 * <p>
 * Precision is undefined when tp = 0 and fp = 0, i.e. the class was never predicted by the model.
 * Recall is undefined when tp = 0 and fn = 0, i.e. the class never occurs in the labels.
 * The message states the cause that matches {@code metric}.
 *
 * @param warnings output buffer to append to
 * @param list     indices of the excluded classes
 * @param metric   metric name, e.g. "precision" or "recall"
 */
private void warningHelper(StringBuilder warnings, List<Integer> list, String metric ){
    final boolean plural = list.size() != 1;
    final String wasWere = plural ? "were" : "was";
    // Recall can only be 0/0 when the class never appeared as an actual label; precision, when it was never predicted
    final String reason = "recall".equals(metric) ? "never present in the labels" : "never predicted by the model";
    warnings.append("Warning: ").append(list.size()).append(plural ? " classes " : " class ").append(wasWere)
            .append(" ").append(reason).append(" and ").append(wasWere).append(" excluded from average ")
            .append(metric).append("\nClasses excluded from average ").append(metric).append(": ")
            .append(list)
            .append("\n");
}
/**
 * Returns the precision for a given class.
 * {@link #DEFAULT_EDGE_VALUE} is returned when there are no true or false positives (0/0).
 *
 * @param classLabel the class index
 * @return the precision for the class
 */
public double precision(Integer classLabel) {
    final double edge = DEFAULT_EDGE_VALUE;
    return precision(classLabel, edge);
}
/**
 * Returns the precision for a given class, computed from its true-positive and false-positive counts.
 *
 * @param classLabel the class index
 * @param edgeCase   What to output in case of 0/0
 * @return the precision for the class
 */
public double precision(Integer classLabel, double edgeCase) {
    final long tp = (long) truePositives.getCount(classLabel);
    final long fp = (long) falsePositives.getCount(classLabel);
    return EvaluationUtils.precision(tp, fp, edgeCase);
}
/**
 * Returns the macro-averaged precision over all known classes, based on the guesses so far.
 * Equivalent to {@code precision(EvaluationAveraging.Macro)}.
 *
 * @return the average precision based on guesses so far
 */
public double precision() {
    final EvaluationAveraging averaging = EvaluationAveraging.Macro;
    return precision(averaging);
}
/**
 * Calculates the average precision over all classes, using macro or micro averaging.
 * <ul>
 *   <li>Macro: the unweighted mean of per-class precisions. Classes with tp = 0 and fp = 0
 *       (precision = 0/0) are excluded from the mean.</li>
 *   <li>Micro: the precision of the tp and fp counts summed over all classes.</li>
 * </ul>
 *
 * @param averaging Averaging method - macro or micro
 * @return Average precision
 * @throws UnsupportedOperationException for any other averaging value
 */
public double precision(EvaluationAveraging averaging){
    final int nClasses = confusion().getClasses().size();
    if (averaging == EvaluationAveraging.Micro) {
        long tpSum = 0;
        long fpSum = 0;
        for (int c = 0; c < nClasses; c++) {
            tpSum += truePositives.getCount(c);
            fpSum += falsePositives.getCount(c);
        }
        return EvaluationUtils.precision(tpSum, fpSum, DEFAULT_EDGE_VALUE);
    }
    if (averaging != EvaluationAveraging.Macro) {
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
    }
    double sum = 0.0;
    int included = 0;
    for (int c = 0; c < nClasses; c++) {
        final double classPrecision = precision(c, -1);
        if (classPrecision == -1) {
            continue; // 0/0: excluded from the average
        }
        sum += classPrecision;
        included++;
    }
    return sum / included;
}
/**
 * Counts the classes left out of the macro-averaged precision.
 * These are the classes with no predictions, where precision would be the edge case 0/0.
 *
 * @return Number of classes excluded from the average precision
 */
public int averagePrecisionNumClassesExcluded() {
    final String metric = "precision";
    return numClassesExcluded(metric);
}
/**
 * Counts the classes left out of the macro-averaged recall.
 * These are the classes with tp = 0 and fn = 0, where recall would be the edge case 0/0.
 *
 * @return Number of classes excluded from the average recall
 */
public int averageRecallNumClassesExcluded(){
    final String metric = "recall";
    return numClassesExcluded(metric);
}
/**
 * Counts the classes left out of the macro-averaged F1.
 * These are the classes whose precision or recall is 0/0.
 *
 * @return Number of classes excluded from the average F1
 */
public int averageF1NumClassesExcluded(){
    final String metric = "f1";
    return numClassesExcluded(metric);
}
/**
 * Counts the classes left out of the macro-averaged F-beta.
 * These are the classes whose precision or recall is 0/0. Exclusion does not depend on beta.
 *
 * @return Number of classes excluded from the average FBeta
 */
public int averageFBetaNumClassesExcluded(){
    final String metric = "fbeta";
    return numClassesExcluded(metric);
}
/**
 * Counts the classes for which the given metric is undefined (0/0).
 * Each class's metric is evaluated with -1 as the edge value, which signals the undefined case.
 *
 * @param metric one of "precision", "recall", "f1" or "fbeta" (case-insensitive)
 * @throws RuntimeException for an unknown metric name (only when there is at least one class)
 */
private int numClassesExcluded(String metric) {
    final int classCount = confusion().getClasses().size();
    int excluded = 0;
    for (int idx = 0; idx < classCount; idx++) {
        final String key = metric.toLowerCase();
        final double value;
        if ("precision".equals(key)) {
            value = precision(idx, -1);
        } else if ("recall".equals(key)) {
            value = recall(idx, -1);
        } else if ("f1".equals(key) || "fbeta".equals(key)) {
            value = fBeta(1.0, idx, -1);
        } else {
            throw new RuntimeException("Unknown metric: " + metric);
        }
        if (value == -1) {
            excluded++;
        }
    }
    return excluded;
}
/**
 * Returns the recall for a given class.
 * {@link #DEFAULT_EDGE_VALUE} is returned when there are no true positives or false negatives (0/0).
 *
 * @param classLabel the class index
 * @return Recall rate as a double
 */
public double recall(int classLabel) {
    final double edge = DEFAULT_EDGE_VALUE;
    return recall(classLabel, edge);
}
/**
 * Returns the recall for a given class, computed from its true-positive and false-negative counts.
 *
 * @param classLabel the class index
 * @param edgeCase   What to output in case of 0/0
 * @return Recall rate as a double
 */
public double recall(int classLabel, double edgeCase) {
    final long tp = (long) truePositives.getCount(classLabel);
    final long fn = (long) falseNegatives.getCount(classLabel);
    return EvaluationUtils.recall(tp, fn, edgeCase);
}
/**
 * Returns the macro-averaged recall over all known classes, based on the guesses so far.
 * Equivalent to {@code recall(EvaluationAveraging.Macro)}.
 *
 * @return the recall for the outcomes
 */
public double recall() {
    final EvaluationAveraging averaging = EvaluationAveraging.Macro;
    return recall(averaging);
}
/**
 * Calculates the average recall over all classes, using macro or micro averaging.
 * <ul>
 *   <li>Macro: the unweighted mean of per-class recalls. Classes with tp = 0 and fn = 0
 *       (recall = 0/0) are excluded from the mean.</li>
 *   <li>Micro: the recall of the tp and fn counts summed over all classes.</li>
 * </ul>
 *
 * @param averaging Averaging method - macro or micro
 * @return Average recall
 * @throws UnsupportedOperationException for any other averaging value
 */
public double recall(EvaluationAveraging averaging){
    final int nClasses = confusion().getClasses().size();
    if (averaging == EvaluationAveraging.Micro) {
        long tpSum = 0;
        long fnSum = 0;
        for (int c = 0; c < nClasses; c++) {
            tpSum += truePositives.getCount(c);
            fnSum += falseNegatives.getCount(c);
        }
        return EvaluationUtils.recall(tpSum, fnSum, DEFAULT_EDGE_VALUE);
    }
    if (averaging != EvaluationAveraging.Macro) {
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
    }
    double sum = 0.0;
    int included = 0;
    for (int c = 0; c < nClasses; c++) {
        final double classRecall = recall(c, -1);
        if (classRecall == -1) {
            continue; // 0/0: excluded from the average
        }
        sum += classRecall;
        included++;
    }
    return sum / included;
}
/**
 * Returns the false positive rate for a given class.
 * {@link #DEFAULT_EDGE_VALUE} is used for the 0/0 case.
 *
 * @param classLabel the class index
 * @return fpr as a double
 */
public double falsePositiveRate(int classLabel) {
    final double edge = DEFAULT_EDGE_VALUE;
    return falsePositiveRate(classLabel, edge);
}
/**
 * Returns the false positive rate for a given class, computed from its false-positive and true-negative counts.
 *
 * @param classLabel the class index
 * @param edgeCase   What to output in case of 0/0
 * @return fpr as a double
 */
public double falsePositiveRate(int classLabel, double edgeCase) {
    final long fp = (long) falsePositives.getCount(classLabel);
    final long tn = (long) trueNegatives.getCount(classLabel);
    return EvaluationUtils.falsePositiveRate(fp, tn, edgeCase);
}
/**
 * Returns the macro-averaged false positive rate over all known classes, based on the guesses so far.
 *
 * @return the fpr for the outcomes
 */
public double falsePositiveRate() {
    final EvaluationAveraging averaging = EvaluationAveraging.Macro;
    return falsePositiveRate(averaging);
}
/**
 * Calculates the average false positive rate over all classes, using macro or micro averaging.
 * <ul>
 *   <li>Macro: the unweighted mean of per-class rates, over every class.</li>
 *   <li>Micro: the rate of the fp and tn counts summed over all classes.</li>
 * </ul>
 *
 * @param averaging Averaging method - macro or micro
 * @return Average false positive rate
 * @throws UnsupportedOperationException for any other averaging value
 */
public double falsePositiveRate(EvaluationAveraging averaging){
    final int nClasses = confusion().getClasses().size();
    if (averaging == EvaluationAveraging.Micro) {
        long fpSum = 0;
        long tnSum = 0;
        for (int c = 0; c < nClasses; c++) {
            fpSum += falsePositives.getCount(c);
            tnSum += trueNegatives.getCount(c);
        }
        return EvaluationUtils.falsePositiveRate(fpSum, tnSum, DEFAULT_EDGE_VALUE);
    }
    if (averaging != EvaluationAveraging.Macro) {
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
    }
    double total = 0.0;
    for (int c = 0; c < nClasses; c++) {
        total += falsePositiveRate(c);
    }
    return total / nClasses;
}
/**
 * Returns the false negative rate for a given class.
 * {@link #DEFAULT_EDGE_VALUE} is used for the 0/0 case.
 *
 * @param classLabel the class index
 * @return fnr as a double
 */
public double falseNegativeRate(Integer classLabel) {
    final double edge = DEFAULT_EDGE_VALUE;
    return falseNegativeRate(classLabel, edge);
}
/**
 * Returns the false negative rate for a given class, computed from its false-negative and true-positive counts.
 *
 * @param classLabel the class index
 * @param edgeCase   What to output in case of 0/0
 * @return fnr as a double
 */
public double falseNegativeRate(Integer classLabel, double edgeCase) {
    final long fn = (long) falseNegatives.getCount(classLabel);
    final long tp = (long) truePositives.getCount(classLabel);
    return EvaluationUtils.falseNegativeRate(fn, tp, edgeCase);
}
/**
 * Returns the macro-averaged false negative rate over all known classes, based on the guesses so far.
 *
 * @return the fnr for the outcomes
 */
public double falseNegativeRate() {
    final EvaluationAveraging averaging = EvaluationAveraging.Macro;
    return falseNegativeRate(averaging);
}
/**
 * Calculate the average false negative rate for all classes - can specify whether macro or micro averaging
 * should be used.
 * <p>
 * The false negative rate is FN / (FN + TP), matching {@link #falseNegativeRate(Integer, double)}.
 * <ul>
 *   <li>Macro: the unweighted mean of per-class rates, over every class.</li>
 *   <li>Micro: the rate of the FN and TP counts summed over all classes.</li>
 * </ul>
 *
 * @param averaging Averaging method - macro or micro
 * @return Average false negative rate
 * @throws UnsupportedOperationException for any other averaging value
 */
public double falseNegativeRate(EvaluationAveraging averaging){
    int nClasses = confusion().getClasses().size();
    if (averaging == EvaluationAveraging.Macro) {
        double macroFNR = 0.0;
        for (int i = 0; i < nClasses; i++) {
            macroFNR += falseNegativeRate(i);
        }
        macroFNR /= nClasses;
        return macroFNR;
    } else if (averaging == EvaluationAveraging.Micro) {
        long fnCount = 0;
        long tpCount = 0;
        for (int i = 0; i < nClasses; i++) {
            fnCount += falseNegatives.getCount(i);
            // FNR pools true positives (not true negatives), as the per-class overload does
            tpCount += truePositives.getCount(i);
        }
        return EvaluationUtils.falseNegativeRate(fnCount, tpCount, DEFAULT_EDGE_VALUE);
    } else {
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
    }
}
/**
 * False Alarm Rate (FAR): the mean of the macro-averaged false positive rate and the macro-averaged
 * false negative rate, i.e. {@code (falsePositiveRate() + falseNegativeRate()) / 2}.
 * See http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&amp;context=isw
 *
 * @return the false alarm rate for the outcomes
 */
public double falseAlarmRate() {
    final double fpr = falsePositiveRate();
    final double fnr = falseNegativeRate();
    return (fpr + fnr) / 2.0;
}
/**
 * Calculates the F1 score (F-beta with beta = 1) for a given class.
 * 0.0 is returned when precision or recall is 0/0.
 *
 * @param classLabel the class index to calculate f1 for
 * @return the f1 score for the given class
 */
public double f1(int classLabel) {
    final double beta = 1.0;
    return fBeta(beta, classLabel);
}
/**
 * Calculate the f_beta for a given class, where f_beta is defined as:<br>
 * (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).<br>
 * F1 is a special case of f_beta, with beta=1.0. Returns 0.0 when precision or recall is 0/0.
 *
 * @param beta       Beta value to use
 * @param classLabel Class label
 * @return F_beta
 */
public double fBeta(double beta, int classLabel) {
    final double undefinedValue = 0.0;
    return fBeta(beta, classLabel, undefinedValue);
}
/**
 * Calculate the f_beta for a given class, where f_beta is defined as:<br>
 * (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).<br>
 * F1 is a special case of f_beta, with beta=1.0
 *
 * @param beta         Beta value to use
 * @param classLabel   Class label
 * @param defaultValue Default value to use when precision or recall is undefined (0/0 for prec. or recall)
 * @return F_beta
 */
public double fBeta(double beta, int classLabel, double defaultValue){
    // -1 marks the undefined (0/0) case; valid precision/recall values lie in [0, 1]
    final double p = precision(classLabel, -1);
    final double r = recall(classLabel, -1);
    final boolean undefined = p == -1 || r == -1;
    return undefined ? defaultValue : EvaluationUtils.fBeta(beta, p, r);
}
/**
 * Calculates the (macro) average F1 score across all classes.
 * <p>
 * Per class, F1 = 2 * TP / (2TP + FP + FN). Classes whose precision or recall is 0/0 are excluded
 * from the macro average. With exactly 2 classes, the F1 of class 1 is returned instead.
 *
 * @return the f1 score or harmonic mean of precision and recall based on current guesses
 */
public double f1() {
    final EvaluationAveraging averaging = EvaluationAveraging.Macro;
    return f1(averaging);
}
/**
 * Calculate the average F1 score across all classes, using macro or micro averaging.
 * Equivalent to {@code fBeta(1.0, averaging)}.
 *
 * @param averaging Averaging method to use
 * @return the average F1 score
 */
public double f1(EvaluationAveraging averaging){
    final double beta = 1.0;
    return fBeta(beta, averaging);
}
/**
 * Calculate the average F_beta score across all classes, using macro or micro averaging.
 * <p>
 * Special case: with exactly 2 classes, class 1 is treated as the positive class. Its F_beta is returned
 * and {@code averaging} is not consulted. Otherwise:
 * <ul>
 *   <li>Macro: the mean of per-class F_beta, excluding classes whose precision or recall is 0/0.</li>
 *   <li>Micro: F_beta of the tp, fp and fn counts summed over all classes.</li>
 * </ul>
 *
 * @param beta      Beta value to use
 * @param averaging Averaging method to use
 * @return the average F_beta score
 * @throws UnsupportedOperationException for any other averaging value (when there are not exactly 2 classes)
 */
public double fBeta(double beta, EvaluationAveraging averaging){
    final int nClasses = confusion().getClasses().size();
    if (nClasses == 2) {
        final long tp = (long) truePositives.getCount(1);
        final long fp = (long) falsePositives.getCount(1);
        final long fn = (long) falseNegatives.getCount(1);
        return EvaluationUtils.fBeta(beta, tp, fp, fn);
    }
    if (averaging == EvaluationAveraging.Micro) {
        long tpSum = 0;
        long fpSum = 0;
        long fnSum = 0;
        for (int c = 0; c < nClasses; c++) {
            tpSum += truePositives.getCount(c);
            fpSum += falsePositives.getCount(c);
            fnSum += falseNegatives.getCount(c);
        }
        return EvaluationUtils.fBeta(beta, tpSum, fpSum, fnSum);
    }
    if (averaging != EvaluationAveraging.Macro) {
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
    }
    double sum = 0.0;
    int included = 0;
    for (int c = 0; c < nClasses; c++) {
        final double classFBeta = fBeta(beta, c, -1);
        if (classFBeta == -1) {
            continue; // precision or recall undefined: excluded from the average
        }
        sum += classFBeta;
        included++;
    }
    return sum / included;
}
/**
 * Calculate the G-measure for the given class.
 * It combines that class's precision and recall, both using the default edge value for 0/0.
 *
 * @param output The class index
 * @return The G-measure for the specified output
 */
public double gMeasure(int output){
    return EvaluationUtils.gMeasure(precision(output), recall(output));
}
/**
 * Calculates the average G-measure over all classes, using micro or macro averaging.
 * <ul>
 *   <li>Macro: the unweighted mean of per-class G-measures, over every class.</li>
 *   <li>Micro: the G-measure of the precision and recall obtained from tp, fp and fn counts summed over
 *       all classes.</li>
 * </ul>
 *
 * @param averaging Averaging method to use
 * @return Average G measure
 * @throws UnsupportedOperationException for any other averaging value
 */
public double gMeasure(EvaluationAveraging averaging){
    final int nClasses = confusion().getClasses().size();
    if (averaging == EvaluationAveraging.Micro) {
        long tpSum = 0;
        long fpSum = 0;
        long fnSum = 0;
        for (int c = 0; c < nClasses; c++) {
            tpSum += truePositives.getCount(c);
            fpSum += falsePositives.getCount(c);
            fnSum += falseNegatives.getCount(c);
        }
        final double pooledPrecision = EvaluationUtils.precision(tpSum, fpSum, DEFAULT_EDGE_VALUE);
        final double pooledRecall = EvaluationUtils.recall(tpSum, fnSum, DEFAULT_EDGE_VALUE);
        return EvaluationUtils.gMeasure(pooledPrecision, pooledRecall);
    }
    if (averaging != EvaluationAveraging.Macro) {
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
    }
    double total = 0.0;
    for (int c = 0; c < nClasses; c++) {
        total += gMeasure(c);
    }
    return total / nClasses;
}
/**
 * Accuracy: (TP + TN) / (P + N).
 * <p>
 * Computed as the sum of the confusion matrix diagonal divided by the number of examples evaluated so far.
 * The result is NaN if no examples have been evaluated.
 *
 * @return the accuracy of the guesses so far
 */
public double accuracy() {
    final ConfusionMatrix<Integer> matrix = confusion();
    final int nClasses = matrix.getClasses().size();
    int diagonal = 0;
    for (int c = 0; c < nClasses; c++) {
        diagonal += matrix.getCount(c, c);
    }
    return diagonal / (double) getNumRowCounter();
}
/**
 * Top N accuracy of the predictions so far. For top N = 1 (default), equivalent to {@link #accuracy()}.
 * Returns 0.0 when top N &gt; 1 and no examples have been counted yet.
 *
 * @return Top N accuracy
 */
public double topNAccuracy() {
    if (topN <= 1) {
        return accuracy();
    }
    return topNTotalCount == 0 ? 0.0 : topNCorrectCount / (double) topNTotalCount;
}
/**
 * Calculates the binary Matthews correlation coefficient for the specified class, treating it as
 * the positive class.<br>
 * MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))<br>
 *
 * @param classIdx Class index to calculate Matthews correlation coefficient for
 * @return the Matthews correlation coefficient for that class
 */
public double matthewsCorrelation(int classIdx){
    final long tp = (long) truePositives.getCount(classIdx);
    final long fp = (long) falsePositives.getCount(classIdx);
    final long fn = (long) falseNegatives.getCount(classIdx);
    final long tn = (long) trueNegatives.getCount(classIdx);
    return EvaluationUtils.matthewsCorrelation(tp, fp, fn, tn);
}
/**
 * Calculate the average binary Matthews correlation coefficient (MCC), using macro or micro averaging.<br>
 * MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))<br>
 * Macro: unweighted mean of {@link #matthewsCorrelation(int)} over all classes.<br>
 * Micro: TP, FP, FN and TN are summed over all classes and a single MCC is computed from the totals.<br>
 * Note: This is NOT the same as the multi-class Matthews correlation coefficient
 *
 * @param averaging Averaging approach
 * @return Average MCC
 * @throws UnsupportedOperationException if the averaging approach is not Macro or Micro
 */
public double matthewsCorrelation(EvaluationAveraging averaging){
    final int numClasses = confusion().getClasses().size();
    if (averaging == EvaluationAveraging.Macro) {
        double total = 0.0;
        for (int c = 0; c < numClasses; c++)
            total += matthewsCorrelation(c);
        return total / numClasses;
    }
    if (averaging != EvaluationAveraging.Micro)
        throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);

    // Micro averaging: pool the raw counts of every class first
    long tp = 0, fp = 0, fn = 0, tn = 0;
    for (int c = 0; c < numClasses; c++) {
        tp += truePositives.getCount(c);
        fp += falsePositives.getCount(c);
        fn += falseNegatives.getCount(c);
        tn += trueNegatives.getCount(c);
    }
    return EvaluationUtils.matthewsCorrelation(tp, fp, fn, tn);
}
// Access counter methods
/**
 * True positives, per class: examples whose actual class is the given class and which were also predicted
 * as that class.
 *
 * @return map from class index (0 to numClasses-1) to the total true positives so far
 */
public Map<Integer, Integer> truePositives() {
    return convertToMap(truePositives, confusion().getClasses().size());
}
/**
 * True negatives, per class: examples that neither belong to the given class nor were predicted as it.
 *
 * @return map from class index (0 to numClasses-1) to the total true negatives so far
 */
public Map<Integer, Integer> trueNegatives() {
    final int numClasses = confusion().getClasses().size();
    return convertToMap(trueNegatives, numClasses);
}
/**
 * False positives, per class: examples predicted as the given class although their actual class differs
 * (a wrong guess).
 *
 * @return map from class index (0 to numClasses-1) to the count of the false positives
 */
public Map<Integer, Integer> falsePositives() {
    final int numClasses = confusion().getClasses().size();
    return convertToMap(falsePositives, numClasses);
}
/**
 * False negatives, per class: examples whose actual class is the given class but which were predicted as
 * some other class (i.e. incorrectly rejected / missed).
 *
 * @return map from class index (0 to numClasses-1) to the total false negatives so far
 */
public Map<Integer, Integer> falseNegatives() {
    return convertToMap(falseNegatives, confusion().getClasses().size());
}
/**
 * Total actual negatives per class: true negatives + false positives, i.e. every example whose actual class
 * is not the given class.
 *
 * @return map from class index to the overall negative count
 */
public Map<Integer, Integer> negative() {
    return addMapsByKey(trueNegatives(), falsePositives());
}
/**
 * Total actual positives per class: true positives + false negatives, i.e. every example whose actual
 * class is the given class.
 *
 * @return map from class index to the overall positive count
 */
public Map<Integer, Integer> positive() {
    Map<Integer, Integer> tp = truePositives();
    Map<Integer, Integer> fn = falseNegatives();
    return addMapsByKey(tp, fn);
}
/**
 * Copies the counts for class indices 0..(maxCount-1) from a Counter into a new HashMap.
 * Each double count is truncated to an int by a cast.
 */
private Map<Integer, Integer> convertToMap(Counter<Integer> counter, int maxCount) {
    Map<Integer, Integer> result = new HashMap<>();
    int classIdx = 0;
    while (classIdx < maxCount) {
        result.put(classIdx, (int) counter.getCount(classIdx));
        classIdx++;
    }
    return result;
}
/**
 * Returns a new map containing every key found in either input, mapped to the sum of its values in both
 * maps. A key that is absent (or mapped to null) in one map contributes 0 from that map.
 */
private Map<Integer, Integer> addMapsByKey(Map<Integer, Integer> first, Map<Integer, Integer> second) {
    Set<Integer> allKeys = new HashSet<>(first.keySet());
    allKeys.addAll(second.keySet());

    Map<Integer, Integer> summed = new HashMap<>();
    for (Integer key : allKeys) {
        Integer a = first.get(key);
        Integer b = second.get(key);
        int total = (a == null ? 0 : a) + (b == null ? 0 : b);
        summed.put(key, total);
    }
    return summed;
}
// Incrementing counters
/**
 * Records one more true positive for the given class.
 *
 * @param classLabel index of the class to update
 */
public void incrementTruePositives(Integer classLabel) {
    incrementByOne(truePositives, classLabel);
}

/**
 * Records one more true negative for the given class.
 *
 * @param classLabel index of the class to update
 */
public void incrementTrueNegatives(Integer classLabel) {
    incrementByOne(trueNegatives, classLabel);
}

/**
 * Records one more false negative for the given class.
 *
 * @param classLabel index of the class to update
 */
public void incrementFalseNegatives(Integer classLabel) {
    incrementByOne(falseNegatives, classLabel);
}

/**
 * Records one more false positive for the given class.
 *
 * @param classLabel index of the class to update
 */
public void incrementFalsePositives(Integer classLabel) {
    incrementByOne(falsePositives, classLabel);
}

/** Adds exactly 1.0 to the count stored under {@code classLabel} in {@code counter}. */
private static void incrementByOne(Counter<Integer> counter, Integer classLabel) {
    counter.incrementCount(classLabel, 1.0);
}
// Other misc methods
/**
 * Records a single (actual, predicted) observation in the confusion matrix.
 *
 * @param real  the actual (true) class index
 * @param guess the class index predicted by the system
 */
public void addToConfusion(Integer real, Integer guess) {
    final Integer actual = real;
    final Integer predicted = guess;
    confusion().add(actual, predicted);
}
/**
 * Returns how many examples had the given class as their actual label, according to the confusion matrix.
 *
 * @param clazz the class index
 * @return the number of times the label actually occurred
 */
public int classCount(Integer clazz) {
    final int actualTotal = confusion().getActualTotal(clazz);
    return actualTotal;
}
/**
 * Returns the running count of evaluated rows (examples); this is the denominator used by {@link #accuracy()}.
 *
 * @return number of rows counted so far
 */
public int getNumRowCounter() {
    return this.numRowCounter;
}
/**
 * Return the number of correct predictions according to top N value. For top N = 1 (default) this is the
 * sum of the confusion matrix diagonal, i.e. the plain number of correct predictions; otherwise it is the
 * separately tracked top N correct counter.
 * <p>
 * Side effect: if the confusion matrix has not been created yet, an empty one is created here.
 *
 * @return Number of correct top N predictions
 */
@JsonIgnore
public int getTopNCorrectCount() {
    if (confusion == null) {
        // Lazily create an empty matrix so confusion() has something to work with
        confusion = new ConfusionMatrix<>();
    }
    if (topN > 1) {
        return topNCorrectCount;
    }
    // Top 1: correct predictions are exactly the diagonal entries of the confusion matrix
    final int numClasses = confusion().getClasses().size();
    int diagonalSum = 0;
    for (int c = 0; c < numClasses; c++) {
        diagonalSum += confusion().getCount(c, c);
    }
    return diagonalSum;
}
/**
 * Return the total number of top N evaluations. For top N = 1 this is exactly {@link #getNumRowCounter()}.
 * For top N &gt; 1 it is a separate counter, because it may differ when {@link #eval(int, int)} is used: top N
 * accuracy cannot be calculated in that case (it needs the full probability distribution, not just the
 * predicted/actual indices).
 *
 * @return Total number of top N predictions
 */
@JsonIgnore
public int getTopNTotalCount() {
    return topN > 1 ? topNTotalCount : getNumRowCounter();
}
/**
 * Returns the label for the given class index, as resolved by {@code resolveLabelForClass}.
 *
 * @param clazz class index
 * @return label for that class
 */
public String getClassLabel(Integer clazz) {
    final String label = resolveLabelForClass(clazz);
    return label;
}
/**
 * Returns the confusion matrix backing this evaluation. This is the internal instance rather than a copy,
 * and it may be null if the matrix has not been created yet.
 *
 * @return confusion matrix variable for this evaluation (possibly null)
 */
@JsonIgnore
public ConfusionMatrix<Integer> getConfusionMatrix() {
    return this.confusion;
}
/**
 * Merge the other evaluation object into this one. Afterwards, this Evaluation instance contains the counts from
 * both: TP/FP/TN/FN counters, confusion matrix entries, row count, top N counters, and any per-record prediction
 * metadata recorded via the "evaluate with metadata" method.
 * <p>
 * Labels are copied from {@code other} only if this instance has none. If the two instances use different top N
 * values, a warning is logged and the top N counters are summed anyway.
 *
 * @param other Evaluation object to merge into this one. Ignored if null.
 */
@Override
public void merge(Evaluation other) {
    if (other == null)
        return;
    truePositives.incrementAll(other.truePositives);
    falsePositives.incrementAll(other.falsePositives);
    trueNegatives.incrementAll(other.trueNegatives);
    falseNegatives.incrementAll(other.falseNegatives);
    if (confusion == null) {
        if (other.confusion != null)
            confusion = new ConfusionMatrix<>(other.confusion);
    } else {
        if (other.confusion != null)
            confusion().add(other.confusion);
    }
    numRowCounter += other.numRowCounter;
    if (labelsList.isEmpty())
        labelsList.addAll(other.labelsList);
    if (topN != other.topN) {
        log.warn("Different topN values ({} vs {}) detected during Evaluation merging. Top N accuracy may not be accurate.",
                        topN, other.topN);
    }
    this.topNCorrectCount += other.topNCorrectCount;
    this.topNTotalCount += other.topNTotalCount;
    // Carry over per-record metadata so getPredictionErrors()/getPredictions*() on the merged
    // instance cover the records of both evaluations
    mergeMetaData(other);
}

/**
 * Appends every (actual, predicted) -&gt; metadata record of {@code other} to this instance's metadata map.
 * Does nothing if {@code other} has no recorded metadata. New lists are created, so no list is shared between
 * the two instances.
 */
private void mergeMetaData(Evaluation other) {
    if (other.confusionMatrixMetaData == null)
        return;
    // Snapshot the entries and lists first, so merging an instance into itself never iterates a collection
    // that is being appended to
    List<Map.Entry<Pair<Integer, Integer>, List<Object>>> entries =
                    new ArrayList<>(other.confusionMatrixMetaData.entrySet());
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : entries) {
        Pair<Integer, Integer> key = entry.getKey(); //(Actual, Predicted)
        List<Object> metas = new ArrayList<>(entry.getValue());
        for (Object meta : metas) {
            addToMetaConfusionMatrix(key.getFirst(), key.getSecond(), meta);
        }
    }
}
/**
 * Get a String representation of the confusion matrix.
 * <p>
 * The output is a header row listing the predicted class indices, then a " Actual:" caption line, then one row
 * per actual class. Each row contains the class index, its label and the count for every predicted class.
 * If the labels list has fewer entries than there are classes, the class index is used as the label for the
 * remaining rows instead of throwing an IndexOutOfBoundsException.
 *
 * @return the formatted confusion matrix
 */
public String confusionToString() {
    int nClasses = confusion().getClasses().size();
    //First: work out the longest label size (null labels are skipped instead of causing an NPE)
    int maxLabelSize = 0;
    for (String s : labelsList) {
        if (s != null)
            maxLabelSize = Math.max(maxLabelSize, s.length());
    }
    //Build the formatting for the rows: index, padded label, then one count column per predicted class
    int labelSize = Math.max(maxLabelSize + 5, 10);
    StringBuilder sb = new StringBuilder();
    sb.append("%-3d");
    sb.append("%-");
    sb.append(labelSize);
    sb.append("s | ");
    StringBuilder headerFormat = new StringBuilder();
    headerFormat.append(" %-").append(labelSize).append("s ");
    for (int i = 0; i < nClasses; i++) {
        sb.append("%7d");
        headerFormat.append("%7d");
    }
    String rowFormat = sb.toString();
    StringBuilder out = new StringBuilder();
    //First: header row
    Object[] headerArgs = new Object[nClasses + 1];
    headerArgs[0] = "Predicted:";
    for (int i = 0; i < nClasses; i++)
        headerArgs[i + 1] = i;
    out.append(String.format(headerFormat.toString(), headerArgs)).append("\n");
    //Second: caption line for the actual-class rows
    out.append(" Actual:\n");
    //Finally: data rows
    for (int i = 0; i < nClasses; i++) {
        Object[] args = new Object[nClasses + 2];
        args[0] = i;
        // Fall back to the class index when no label was provided for this class
        args[1] = i < labelsList.size() ? labelsList.get(i) : String.valueOf(i);
        for (int j = 0; j < nClasses; j++) {
            args[j + 2] = confusion().getCount(i, j);
        }
        out.append(String.format(rowFormat, args));
        out.append("\n");
    }
    return out.toString();
}
/**
 * Records one piece of per-example metadata under its (actual, predicted) class pair. The metadata map and the
 * per-pair list are created lazily when first needed.
 */
private void addToMetaConfusionMatrix(int actual, int predicted, Object metaData) {
    if (confusionMatrixMetaData == null)
        confusionMatrixMetaData = new HashMap<>();

    Pair<Integer, Integer> cell = new Pair<>(actual, predicted);
    List<Object> bucket = confusionMatrixMetaData.get(cell);
    if (bucket == null) {
        bucket = new ArrayList<>();
        confusionMatrixMetaData.put(cell, bucket);
    }
    bucket.add(metaData);
}
/**
 * Get a list of prediction errors, on a per-record basis<br>
 * <p>
 * <b>Note</b>: Prediction errors are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in
 * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts,
 * via {@link #getConfusionMatrix()}
 * <p>
 * Errors are returned ordered by actual class, then by predicted class. Records with actual == predicted are excluded.
 *
 * @return A list of prediction errors, or null if no metadata has been recorded
 */
@JsonIgnore
public List<Prediction> getPredictionErrors() {
    if (this.confusionMatrixMetaData == null)
        return null;

    // Ascending order on (actual, predicted)
    Comparator<Map.Entry<Pair<Integer, Integer>, List<Object>>> byActualThenPredicted =
                    new Comparator<Map.Entry<Pair<Integer, Integer>, List<Object>>>() {
                        @Override
                        public int compare(Map.Entry<Pair<Integer, Integer>, List<Object>> a,
                                        Map.Entry<Pair<Integer, Integer>, List<Object>> b) {
                            int cmp = Integer.compare(a.getKey().getFirst(), b.getKey().getFirst());
                            return cmp != 0 ? cmp : Integer.compare(a.getKey().getSecond(), b.getKey().getSecond());
                        }
                    };
    List<Map.Entry<Pair<Integer, Integer>, List<Object>>> entries =
                    new ArrayList<>(confusionMatrixMetaData.entrySet());
    Collections.sort(entries, byActualThenPredicted);

    List<Prediction> errors = new ArrayList<>();
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> e : entries) {
        Integer actual = e.getKey().getFirst();
        Integer predicted = e.getKey().getSecond();
        if (!actual.equals(predicted)) {
            for (Object meta : e.getValue()) {
                errors.add(new Prediction(actual, predicted, meta));
            }
        }
    }
    return errors;
}
/**
 * Get a list of predictions, for all data with the specified <i>actual</i> class, regardless of the predicted
 * class.
 * <p>
 * <b>Note</b>: Predictions are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in
 * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts,
 * via {@link #getConfusionMatrix()}
 *
 * @param actualClass Actual class to get predictions for
 * @return List of predictions, or null if the "evaluate with metadata" method was not used
 */
public List<Prediction> getPredictionsByActualClass(int actualClass) {
    if (confusionMatrixMetaData == null)
        return null;

    List<Prediction> matches = new ArrayList<>();
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : confusionMatrixMetaData.entrySet()) {
        Pair<Integer, Integer> key = entry.getKey(); //(Actual, Predicted)
        if (key.getFirst() != actualClass)
            continue;
        int predicted = key.getSecond();
        for (Object meta : entry.getValue())
            matches.add(new Prediction(actualClass, predicted, meta));
    }
    return matches;
}
/**
 * Get a list of predictions, for all data with the specified <i>predicted</i> class, regardless of the actual data
 * class.
 * <p>
 * <b>Note</b>: Predictions are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in
 * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts,
 * via {@link #getConfusionMatrix()}
 *
 * @param predictedClass Predicted class to get predictions for
 * @return List of predictions, or null if the "evaluate with metadata" method was not used
 */
public List<Prediction> getPredictionByPredictedClass(int predictedClass) {
    if (confusionMatrixMetaData == null)
        return null;
    List<Prediction> out = new ArrayList<>();
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted)
        if (entry.getKey().getSecond() == predictedClass) {
            int actual = entry.getKey().getFirst();
            int predicted = entry.getKey().getSecond();
            for (Object m : entry.getValue()) {
                out.add(new Prediction(actual, predicted, m));
            }
        }
    }
    return out;
}
/**
 * Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair)
 *
 * @param actualClass    Actual class
 * @param predictedClass Predicted class
 * @return List of predictions that match the specified actual/predicted classes (empty if none were recorded for
 *         that pair), or null if the "evaluate with metadata" method was not used
 */
public List<Prediction> getPredictions(int actualClass, int predictedClass) {
    if (confusionMatrixMetaData == null)
        return null;

    List<Prediction> result = new ArrayList<>();
    List<Object> recorded = confusionMatrixMetaData.get(new Pair<>(actualClass, predictedClass));
    if (recorded != null) {
        for (Object meta : recorded)
            result.add(new Prediction(actualClass, predictedClass, meta));
    }
    return result;
}
}