package org.deeplearning4j.eval; import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.transforms.Not; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.shade.jackson.annotation.JsonIgnore; import java.io.Serializable; import java.util.*; /** * ROC (Receiver Operating Characteristic) for multi-task binary classifiers, using the specified number of threshold steps. * <p> * Some ROC implementations will automatically calculate the threshold points based on the data set to give a 'smoother' * ROC curve (or optimal cut points for diagnostic purposes). This implementation currently uses fixed steps of size * 1.0 / thresholdSteps, as this allows easy implementation for batched and distributed evaluation scenarios (where the * full data set is not available in memory on any one machine at once). * <p> * Unlike {@link ROC} (which supports a single binary label (as a single column probability, or 2 column 'softmax' probability * distribution), ROCBinary assumes that all outputs are independent binary variables. This also differs from * {@link ROCMultiClass}, which should be used for multi-class (single non-binary) cases. * <p> * ROCBinary supports per-example and per-output masking: for per-output masking, any particular output may be absent * (mask value 0) and hence won't be included in the calculated ROC. */ @EqualsAndHashCode(callSuper = true) @Data @NoArgsConstructor public class ROCBinary extends BaseEvaluation<ROCBinary> { public static final int DEFAULT_PRECISION = 4; private int thresholdSteps; private long[] countActualPositive; private long[] countActualNegative; private Map<Double, CountsForThreshold> countsForThresholdMap; private List<String> labels; public ROCBinary(int thresholdSteps) { this.thresholdSteps = thresholdSteps; countActualNegative = null; countActualPositive = null; } @Override public void reset() { countActualPositive = null; countActualNegative = null; countsForThresholdMap = null; } @Override public void eval(INDArray labels, INDArray networkPredictions) { eval(labels, networkPredictions, (INDArray) null); } @Override public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) { if (countActualPositive != null && countActualPositive.length != labels.size(1)) { throw new IllegalStateException("Labels array does not match stored state size. Expected labels array with " + "size " + countActualPositive.length + ", got labels array with size " + labels.size(1)); } if (labels.rank() == 3) { evalTimeSeries(labels, networkPredictions, maskArray); return; } if (countActualPositive == null) { //Initialize countActualPositive = new long[labels.size(1)]; countActualNegative = new long[labels.size(1)]; countsForThresholdMap = new LinkedHashMap<>(); double step = 1.0 / thresholdSteps; for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; countsForThresholdMap.put(currThreshold, new CountsForThreshold(currThreshold, labels.size(1))); } } //First: need to increment actual positive/negative (label counts) for each output INDArray actual1 = labels; INDArray actual0 = Nd4j.getExecutioner().execAndReturn(new Not(labels.dup())); if (maskArray != null) { actual1 = actual1.mul(maskArray); actual0.muli(maskArray); } int[] countActualPosThisBatch = actual1.sum(0).data().asInt(); int[] countActualNegThisBatch = actual0.sum(0).data().asInt(); addInPlace(countActualPositive, countActualPosThisBatch); addInPlace(countActualNegative, countActualNegThisBatch); //Here: calculate true positive rate (TPR) vs. false positive rate (FPR) at different threshold double step = 1.0 / thresholdSteps; for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold Condition condGeq = Conditions.greaterThanOrEqual(currThreshold); Condition condLeq = Conditions.lessThanOrEqual(currThreshold); Op op = new CompareAndSet(networkPredictions.dup(), 1.0, condGeq); INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op); op = new CompareAndSet(predictedClass1, 0.0, condLeq); predictedClass1 = Nd4j.getExecutioner().execAndReturn(op); //True positives: occur whet the predicted and actual are both 1s //False positives: occur when predicted 1, actual is 0 INDArray isTruePositive = predictedClass1.mul(actual1); INDArray isFalsePositive = predictedClass1.mul(actual0); //Apply mask array: if (maskArray != null) { if (Arrays.equals(labels.shape(), maskArray.shape())) { //Per output masking isTruePositive.muli(maskArray); isFalsePositive.muli(maskArray); } else { //Per-example masking isTruePositive.muliColumnVector(maskArray); isFalsePositive.muliColumnVector(maskArray); } } //TP/FP counts for this threshold int[] truePositiveCount = isTruePositive.sum(0).data().asInt(); int[] falsePositiveCount = isFalsePositive.sum(0).data().asInt(); CountsForThreshold cft = countsForThresholdMap.get(currThreshold); cft.incrementTruePositive(truePositiveCount); cft.incrementFalsePositive(falsePositiveCount); } } private static void addInPlace(long[] addTo, int[] toAdd) { for (int i = 0; i < addTo.length; i++) { addTo[i] += toAdd[i]; } } private static void addInPlace(long[] addTo, long[] toAdd) { for (int i = 0; i < addTo.length; i++) { addTo[i] += toAdd[i]; } } @Override public void merge(ROCBinary other) { if (this.countActualPositive == null) { this.countActualPositive = other.countActualPositive; this.countActualNegative = other.countActualNegative; this.countsForThresholdMap = other.countsForThresholdMap; return; } else if (other.countActualPositive == null) { return; } if (this.countActualPositive.length != other.countActualPositive.length) { throw new IllegalStateException("Cannot merge ROCBinary instances with different number of coulmns. " + "numColumns = " + this.countActualPositive.length + "; other numColumns = " + other.countActualPositive.length); } //Both have data addInPlace(this.countActualPositive, other.countActualPositive); addInPlace(this.countActualNegative, other.countActualNegative); for (Map.Entry<Double, CountsForThreshold> e : countsForThresholdMap.entrySet()) { CountsForThreshold o = other.countsForThresholdMap.get(e.getKey()); e.getValue().incrementTruePositive(o.getCountTruePositive()); e.getValue().incrementFalsePositive(o.getCountFalsePositive()); } } private void assertIndex(int outputNum) { if (countActualPositive == null) { throw new UnsupportedOperationException("ROCBinary does not have any stats: eval must be called first"); } if (outputNum < 0 || outputNum >= countActualPositive.length) { throw new IllegalArgumentException("Invalid input: output number must be between 0 and " + (outputNum - 1)); } } /** * Returns the number of labels - (i.e., size of the prediction/labels arrays) - if known. Returns -1 otherwise */ public int numLabels() { if (countActualPositive == null) { return -1; } return countActualPositive.length; } /** * Get the actual positive count (accounting for any masking) for the specified output/column * * @param outputNum Index of the output (0 to {@link #numLabels()}-1) */ public long getCountActualPositive(int outputNum) { assertIndex(outputNum); return countActualPositive[outputNum]; } /** * Get the actual negative count (accounting for any masking) for the specified output/column * * @param outputNum Index of the output (0 to {@link #numLabels()}-1) */ public long getCountActualNegative(int outputNum) { assertIndex(outputNum); return countActualNegative[outputNum]; } /** * Get the ROC curve, as a set of points * * @param outputNum Index of the output (0 to {@link #numLabels()}-1) * @return ROC curve, as a list of points */ public List<ROCBinary.ROCValue> getResults(int outputNum) { assertIndex(outputNum); List<ROCBinary.ROCValue> out = new ArrayList<>(countsForThresholdMap.size()); for (Map.Entry<Double, ROCBinary.CountsForThreshold> entry : countsForThresholdMap.entrySet()) { double t = entry.getKey(); ROCBinary.CountsForThreshold c = entry.getValue(); double tpr = c.getCountTruePositive()[outputNum] / ((double) countActualPositive[outputNum]); double fpr = c.getCountFalsePositive()[outputNum] / ((double) countActualNegative[outputNum]); out.add(new ROCBinary.ROCValue(t, tpr, fpr)); } return out; } /** * Get the precision/recall curve, for the specified output * * @param outputNum Index of the output (0 to {@link #numLabels()}-1) * @return the precision/recall curve */ public List<ROCBinary.PrecisionRecallPoint> getPrecisionRecallCurve(int outputNum) { assertIndex(outputNum); //Precision: (true positive count) / (true positive count + false positive count) == true positive rate //Recall: (true positive count) / (true positive count + false negative count) = (TP count) / (total dataset positives) List<ROCBinary.PrecisionRecallPoint> out = new ArrayList<>(countsForThresholdMap.size()); for (Map.Entry<Double, ROCBinary.CountsForThreshold> entry : countsForThresholdMap.entrySet()) { double t = entry.getKey(); ROCBinary.CountsForThreshold c = entry.getValue(); long tpCount = c.getCountTruePositive()[outputNum]; long fpCount = c.getCountFalsePositive()[outputNum]; //For edge cases: http://stats.stackexchange.com/questions/1773/what-are-correct-values-for-precision-and-recall-in-edge-cases //precision == 1 when FP = 0 -> no incorrect positive predictions //recall == 1 when no dataset positives are present (got all 0 of 0 positives) double precision; if (tpCount == 0 && fpCount == 0) { //At this threshold: no predicted positive cases precision = 1.0; } else { precision = tpCount / (double) (tpCount + fpCount); } double recall; if (countActualPositive[outputNum] == 0) { recall = 1.0; } else { recall = tpCount / ((double) countActualPositive[outputNum]); } out.add(new ROCBinary.PrecisionRecallPoint(c.getThreshold(), precision, recall)); } return out; } /** * Get the ROC curve, as a set of (falsePositive, truePositive) points * <p> * Returns a 2d array of {falsePositive, truePositive values}.<br> * Size is [2][thresholdSteps], with out[0][.] being false positives, and out[1][.] being true positives * * @return ROC curve as double[][] */ @JsonIgnore public double[][] getResultsAsArray(int outputNum) { assertIndex(outputNum); double[][] out = new double[2][thresholdSteps + 1]; int i = 0; for (Map.Entry<Double, ROCBinary.CountsForThreshold> entry : countsForThresholdMap.entrySet()) { ROCBinary.CountsForThreshold c = entry.getValue(); double tpr = c.getCountTruePositive()[outputNum] / ((double) countActualPositive[outputNum]); double fpr = c.getCountFalsePositive()[outputNum] / ((double) countActualNegative[outputNum]); out[0][i] = fpr; out[1][i] = tpr; i++; } return out; } /** * Average AUC for all outcomes * @return the average AUC for all outcomes. */ public double calculateAverageAuc() { double ret = 0.0; for(int i = 0; i < numLabels(); i++) { ret += calculateAUC(i); } return ret / (double) numLabels(); } /** * Calculate the AUC - Area Under Curve<br> * Utilizes trapezoidal integration internally * * @param outputNum * @return AUC */ public double calculateAUC(int outputNum) { assertIndex(outputNum); //Calculate AUC using trapezoidal rule List<ROCBinary.ROCValue> list = getResults(outputNum); //Given the points double auc = 0.0; for (int i = 0; i < list.size() - 1; i++) { ROCBinary.ROCValue left = list.get(i); ROCBinary.ROCValue right = list.get(i + 1); //y axis: TPR //x axis: FPR double deltaX = Math.abs(right.getFalsePositiveRate() - left.getFalsePositiveRate()); //Iterating in threshold order, so FPR decreases as threshold increases double avg = (left.getTruePositiveRate() + right.getTruePositiveRate()) / 2.0; auc += deltaX * avg; } return auc; } /** * Set the label names, for printing via {@link #stats()} */ public void setLabelNames(List<String> labels) { if (labels == null) { this.labels = null; return; } this.labels = new ArrayList<>(labels); } @Override public String stats() { return stats(DEFAULT_PRECISION); } public String stats(int printPrecision) { //Calculate AUC and also print counts, for each output StringBuilder sb = new StringBuilder(); int maxLabelsLength = 15; if (labels != null) { for (String s : labels) { maxLabelsLength = Math.max(s.length(), maxLabelsLength); } } String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-10s%-10s"; String header = String.format(patternHeader, "Label", "AUC", "# Pos", "# Neg"); String pattern = "%-" + (maxLabelsLength + 5) + "s" //Label + "%-12." + printPrecision + "f" //AUC + "%-10d%-10d"; //Count pos, count neg sb.append(header); for (int i = 0; i < countActualPositive.length; i++) { double auc = calculateAUC(i); String label = (labels == null ? String.valueOf(i) : labels.get(i)); sb.append("\n").append(String.format(pattern, label, auc, countActualPositive[i], countActualNegative[i])); } return sb.toString(); } @AllArgsConstructor @Data public static class ROCValue { private final double threshold; private final double truePositiveRate; private final double falsePositiveRate; } @AllArgsConstructor @Data public static class PrecisionRecallPoint { private final double classiferThreshold; private final double precision; private final double recall; } @AllArgsConstructor @Data public static class CountsForThreshold implements Serializable, Cloneable { private double threshold; private long[] countTruePositive; private long[] countFalsePositive; public CountsForThreshold(double threshold, int size) { this(threshold, new long[size], new long[size]); } public void incrementTruePositive(int[] counts) { addInPlace(countTruePositive, counts); } public void incrementFalsePositive(int[] counts) { addInPlace(countFalsePositive, counts); } public void incrementTruePositive(long[] counts) { addInPlace(countTruePositive, counts); } public void incrementFalsePositive(long[] counts) { addInPlace(countFalsePositive, counts); } public void incrementTruePositive(long count, int index) { countTruePositive[index] += count; } public void incrementFalsePositive(long count, int index) { countFalsePositive[index] += count; } @Override public ROCBinary.CountsForThreshold clone() { long[] ctp = ArrayUtils.clone(countTruePositive); long[] tfp = ArrayUtils.clone(countFalsePositive); return new ROCBinary.CountsForThreshold(threshold, ctp, tfp); } } }