package org.deeplearning4j.eval;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import java.util.*;
/**
* ROC (Receiver Operating Characteristic) for multi-class classifiers, using the specified number of threshold steps.
* <p>
* The ROC curves are produced by treating the predictions as a set of one-vs-all classifiers, and then calculating
* ROC curves for each. In practice, this means for N classes, we get N ROC curves.
* <p>
* Some ROC implementations will automatically calculate the threshold points based on the data set to give a 'smoother'
* ROC curve (or optimal cut points for diagnostic purposes). This implementation currently uses fixed steps of size
* 1.0 / thresholdSteps, as this allows easy implementation for batched and distributed evaluation scenarios (where the
* full data set is not available in memory on any one machine at once).
*
* @author Alex Black
*/
@Getter
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor
public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
private int thresholdSteps;
private long[] countActualPositive;
private long[] countActualNegative;
private final Map<Integer, Map<Double, ROC.CountsForThreshold>> counts = new LinkedHashMap<>();
/**
* @param thresholdSteps Number of threshold steps to use for the ROC calculation
*/
public ROCMultiClass(int thresholdSteps) {
this.thresholdSteps = thresholdSteps;
}
@Override
public void reset() {
countActualPositive = null;
countActualNegative = null;
}
@Override
public String stats() {
return "Average AUC: [" + calculateAverageAUC() + "]";
}
/**
* Evaluate (collect statistics for) the given minibatch of data.
* For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)}
*
* @param labels Labels / true outcomes
* @param predictions Predictions
*/
public void eval(INDArray labels, INDArray predictions) {
if (labels.rank() == 3 && predictions.rank() == 3) {
//Assume time series input -> reshape to 2d
evalTimeSeries(labels, predictions);
}
if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
throw new IllegalArgumentException("Invalid input data shape: labels shape = "
+ Arrays.toString(labels.shape()) + ", predictions shape = "
+ Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
}
double step = 1.0 / thresholdSteps;
if (countActualPositive == null) {
//This must be the first time eval has been called...
int size = labels.size(1);
countActualPositive = new long[size];
countActualNegative = new long[size];
for (int i = 0; i < size; i++) {
Map<Double, ROC.CountsForThreshold> map = new LinkedHashMap<Double, ROC.CountsForThreshold>();
counts.put(i, map);
for (int j = 0; j <= thresholdSteps; j++) {
double currThreshold = j * step;
map.put(currThreshold, new ROC.CountsForThreshold(currThreshold));
}
}
}
if (countActualPositive.length != labels.size(1)) {
throw new IllegalArgumentException(
"Cannot evaluate data: number of label classes does not match previous call. " + "Got "
+ labels.size(1) + " labels (from array shape "
+ Arrays.toString(labels.shape()) + ")"
+ " vs. expected number of label classes = " + countActualPositive.length);
}
for (int i = 0; i < countActualPositive.length; i++) {
//Iterate over each class
INDArray positiveActualColumn = labels.getColumn(i);
INDArray positivePredictedColumn = predictions.getColumn(i);
//Increment global counts - actual positive/negative observed
long currBatchPositiveActualCount = positiveActualColumn.sumNumber().intValue();
countActualPositive[i] += currBatchPositiveActualCount;
countActualNegative[i] += positiveActualColumn.length() - currBatchPositiveActualCount;
//Here: calculate true positive rate (TPR) vs. false positive rate (FPR) at different threshold
for (int j = 0; j <= thresholdSteps; j++) {
double currThreshold = j * step;
//Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold
Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
Condition condLeq = Conditions.lessThanOrEqual(currThreshold);
Op op = new CompareAndSet(positivePredictedColumn.dup(), 1.0, condGeq);
INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
op = new CompareAndSet(predictedClass1, 0.0, condLeq);
predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
//True positives: occur when positive predicted class and actual positive actual class...
//False positive occurs when positive predicted class, but negative actual class
INDArray isTruePositive = predictedClass1.mul(positiveActualColumn); //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
INDArray negativeActualColumn = positiveActualColumn.rsub(1.0);
INDArray isFalsePositive = predictedClass1.mul(negativeActualColumn); //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
//Counts for this batch:
int truePositiveCount = isTruePositive.sumNumber().intValue();
int falsePositiveCount = isFalsePositive.sumNumber().intValue();
//Increment counts for this threshold
ROC.CountsForThreshold thresholdCounts = counts.get(i).get(currThreshold);
thresholdCounts.incrementTruePositive(truePositiveCount);
thresholdCounts.incrementFalsePositive(falsePositiveCount);
}
}
}
/**
* Get the ROC curve, as a set of points
*
* @param classIdx Index of the class to get the (one-vs-all) ROC cur
*
* @return ROC curve, as a list of points
*/
public List<ROC.ROCValue> getResults(int classIdx) {
assertHasBeenFit(classIdx);
List<ROC.ROCValue> out = new ArrayList<>(counts.size());
for (Map.Entry<Double, ROC.CountsForThreshold> entry : counts.get(classIdx).entrySet()) {
double t = entry.getKey();
ROC.CountsForThreshold c = entry.getValue();
double tpr = c.getCountTruePositive() / ((double) countActualPositive[classIdx]);
double fpr = c.getCountFalsePositive() / ((double) countActualNegative[classIdx]);
out.add(new ROC.ROCValue(t, tpr, fpr));
}
return out;
}
/**
* Get the ROC curve, as a set of (falsePositive, truePositive) points
* <p>
* Returns a 2d array of {falsePositive, truePositive values}.<br>
* Size is [2][thresholdSteps], with out[0][.] being false positives, and out[1][.] being true positives
*
* @return ROC curve as double[][]
*/
public double[][] getResultsAsArray(int classIdx) {
assertHasBeenFit(classIdx);
double[][] out = new double[2][thresholdSteps + 1];
int i = 0;
for (Map.Entry<Double, ROC.CountsForThreshold> entry : counts.get(classIdx).entrySet()) {
ROC.CountsForThreshold c = entry.getValue();
double tpr = c.getCountTruePositive() / ((double) countActualPositive[classIdx]);
double fpr = c.getCountFalsePositive() / ((double) countActualNegative[classIdx]);
out[0][i] = fpr;
out[1][i] = tpr;
i++;
}
return out;
}
/**
* Calculate the AUC - Area Under ROC Curve<br>
* Utilizes trapezoidal integration internally
*
* @return AUC
*/
public double calculateAUC(int classIdx) {
assertHasBeenFit(classIdx);
//Calculate AUC using trapezoidal rule
List<ROC.ROCValue> list = getResults(classIdx);
//Given the points
double auc = 0.0;
for (int i = 0; i < list.size() - 1; i++) {
ROC.ROCValue left = list.get(i);
ROC.ROCValue right = list.get(i + 1);
//y axis: TPR
//x axis: FPR
double deltaX = Math.abs(right.getFalsePositiveRate() - left.getFalsePositiveRate()); //Iterating in threshold order, so FPR decreases as threshold increases
double avg = (left.getTruePositiveRate() + right.getTruePositiveRate()) / 2.0;
auc += deltaX * avg;
}
return auc;
}
/**
* Calculate the AUPRC - Area Under Curve Precision Recall <br>
* Utilizes trapezoidal integration internally
*
* @return AUC
*/
public double calculateAUCPR(int classIdx) {
assertHasBeenFit(classIdx);
//Calculate AUCPR using trapezoidal rule
List<ROC.PrecisionRecallPoint> prCurve = getPrecisionRecallCurve(classIdx);
//Trapezoidal integration
double aucpr = 0.0;
for (int i = 0; i < prCurve.size()-1; i++) {
ROC.PrecisionRecallPoint p = prCurve.get(i);
double x0 = prCurve.get(i).getRecall();
double x1 = prCurve.get(i+1).getRecall();
double deltaX = x1 - x0;
double y0 = prCurve.get(i).getPrecision();
double y1 = prCurve.get(i+1).getPrecision();
double avgY = (y0+y1) / 2.0;
aucpr += deltaX*avgY;
}
return aucpr;
}
/**
* Calculate the average (one-vs-all) AUC for all classes
*/
public double calculateAverageAUC() {
assertHasBeenFit(0);
double sum = 0.0;
for (int i = 0; i < countActualPositive.length; i++) {
sum += calculateAUC(i);
}
return sum / countActualPositive.length;
}
public List<ROC.PrecisionRecallPoint> getPrecisionRecallCurve(int classIndex) {
//Precision: (true positive count) / (true positive count + false positive count) == true positive rate
//Recall: (true positive count) / (true positive count + false negative count) = (TP count) / (total dataset positives)
List<ROC.PrecisionRecallPoint> out = new ArrayList<>(counts.get(classIndex).size());
for (Map.Entry<Double, ROC.CountsForThreshold> entry : counts.get(classIndex).entrySet()) {
double t = entry.getKey();
ROC.CountsForThreshold c = entry.getValue();
long tpCount = c.getCountTruePositive();
long fpCount = c.getCountFalsePositive();
//For edge cases: http://stats.stackexchange.com/questions/1773/what-are-correct-values-for-precision-and-recall-in-edge-cases
//precision == 1 when FP = 0 -> no incorrect positive predictions
//recall == 1 when no dataset positives are present (got all 0 of 0 positives)
double precision;
if (tpCount == 0 && fpCount == 0) {
//At this threshold: no predicted positive cases
precision = 1.0;
} else {
precision = tpCount / (double) (tpCount + fpCount);
}
double recall;
if (countActualPositive[classIndex] == 0) {
recall = 1.0;
} else {
recall = tpCount / ((double) countActualPositive[classIndex]);
}
out.add(new ROC.PrecisionRecallPoint(c.getThreshold(), precision, recall));
}
return out;
}
/**
* Merge this ROCMultiClass instance with another.
* This ROCMultiClass instance is modified, by adding the stats from the other instance.
*
* @param other ROCMultiClass instance to combine with this one
*/
@Override
public void merge(ROCMultiClass other) {
if (other.countActualPositive == null) {
//Other has no data
return;
} else if (countActualPositive == null) {
//This instance has no data
this.countActualPositive = Arrays.copyOf(other.countActualPositive, other.countActualPositive.length);
this.countActualNegative = Arrays.copyOf(other.countActualNegative, other.countActualNegative.length);
for (Map.Entry<Integer, Map<Double, ROC.CountsForThreshold>> e : other.counts.entrySet()) {
Map<Double, ROC.CountsForThreshold> m = e.getValue();
Map<Double, ROC.CountsForThreshold> mClone = new LinkedHashMap<>();
for (Map.Entry<Double, ROC.CountsForThreshold> e2 : m.entrySet()) {
mClone.put(e2.getKey(), e2.getValue().clone());
}
this.counts.put(e.getKey(), mClone);
}
} else {
for (int i = 0; i < countActualPositive.length; i++) {
this.countActualPositive[i] += other.countActualPositive[i];
this.countActualNegative[i] += other.countActualNegative[i];
}
for (Integer i : counts.keySet()) {
Map<Double, ROC.CountsForThreshold> thisMap = counts.get(i);
Map<Double, ROC.CountsForThreshold> otherMap = other.counts.get(i);
for (Double d : thisMap.keySet()) {
ROC.CountsForThreshold thisC = thisMap.get(d);
ROC.CountsForThreshold otherC = otherMap.get(d);
thisC.incrementTruePositive(otherC.getCountTruePositive());
thisC.incrementFalsePositive(otherC.getCountFalsePositive());
}
}
}
}
private void assertHasBeenFit(int classIdx) {
if (countActualPositive == null) {
throw new IllegalStateException("Cannot get results: no data has been collected");
}
if (classIdx < 0 || classIdx >= countActualPositive.length) {
throw new IllegalArgumentException("Invalid class index (" + classIdx
+ "): must be in range 0 to numClasses = " + countActualPositive.length);
}
}
}