package org.deeplearning4j.eval; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import java.util.List; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; /** * Created by Alex on 21/03/2017. */ public class ROCBinaryTest { @Test public void testROCBinary() { //Compare ROCBinary to ROC class Nd4j.getRandom().setSeed(12345); int nExamples = 50; int nOut = 4; int[] shape = {nExamples, nOut}; int thresholdSteps = 30; INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5)); INDArray predicted = Nd4j.rand(shape); INDArray binaryPredicted = predicted.gt(0.5); ROCBinary rb = new ROCBinary(thresholdSteps); for (int xe = 0; xe < 2; xe++) { rb.eval(labels, predicted); System.out.println(rb.stats()); double eps = 1e-6; for (int i = 0; i < nOut; i++) { INDArray lCol = labels.getColumn(i); INDArray pCol = predicted.getColumn(i); ROC r = new ROC(thresholdSteps); r.eval(lCol, pCol); double aucExp = r.calculateAUC(); double auc = rb.calculateAUC(i); assertEquals(aucExp, auc, eps); long apExp = r.getCountActualPositive(); long ap = rb.getCountActualPositive(i); assertEquals(ap, apExp); long anExp = r.getCountActualNegative(); long an = rb.getCountActualNegative(i); assertEquals(anExp, an); List<ROC.PrecisionRecallPoint> pExp = r.getPrecisionRecallCurve(); List<ROCBinary.PrecisionRecallPoint> p = rb.getPrecisionRecallCurve(i); assertEquals(pExp.size(), p.size()); for (int j = 0; j < pExp.size(); j++) { ROC.PrecisionRecallPoint a = pExp.get(j); ROCBinary.PrecisionRecallPoint b = p.get(j); assertEquals(a.getClassiferThreshold(), b.getClassiferThreshold(), eps); assertEquals(a.getPrecision(), b.getPrecision(), eps); assertEquals(a.getRecall(), b.getRecall(), eps); } double[][] d1 = r.getResultsAsArray(); double[][] d2 = rb.getResultsAsArray(i); assertEquals(d1.length, d2.length); for (int j = 0; j < d1.length; j++) { assertArrayEquals(d1[j], d2[j], eps); } } rb.reset(); } } @Test public void testRocBinaryMerging() { int nSteps = 30; int nOut = 4; int[] shape1 = {30, nOut}; int[] shape2 = {50, nOut}; Nd4j.getRandom().setSeed(12345); INDArray l1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape1), 0.5)); INDArray l2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape2), 0.5)); INDArray p1 = Nd4j.rand(shape1); INDArray p2 = Nd4j.rand(shape2); ROCBinary rb = new ROCBinary(nSteps); rb.eval(l1, p1); rb.eval(l2, p2); ROCBinary rb1 = new ROCBinary(nSteps); rb1.eval(l1, p1); ROCBinary rb2 = new ROCBinary(nSteps); rb2.eval(l2, p2); rb1.merge(rb2); assertEquals(rb.stats(), rb1.stats()); } @Test public void testROCBinaryPerOutputMasking() { int nSteps = 30; //Here: we'll create a test array, then insert some 'masked out' values, and ensure we get the same results INDArray mask = Nd4j.create(new double[][] {{1, 1, 1}, {0, 1, 1}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}}); INDArray labels = Nd4j.create(new double[][] {{0, 1, 0}, {1, 1, 0}, {0, 1, 1}, {0, 0, 1}, {1, 1, 1}}); //Remove the 1 masked value for each column INDArray labelsExMasked = Nd4j.create(new double[][] {{0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {1, 1, 1}}); INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.4, 0.6}, {0.2, 0.8, 0.4}, {0.6, 0.1, 0.1}, {0.3, 0.7, 0.2}, {0.8, 0.6, 0.6}}); INDArray predictedExMasked = Nd4j .create(new double[][] {{0.9, 0.4, 0.6}, {0.6, 0.8, 0.4}, {0.3, 0.7, 0.1}, {0.8, 0.6, 0.6}}); ROCBinary rbMasked = new ROCBinary(nSteps); rbMasked.eval(labels, predicted, mask); ROCBinary rb = new ROCBinary(nSteps); rb.eval(labelsExMasked, predictedExMasked); assertEquals(rb.stats(), rbMasked.stats()); for (int i = 0; i < 3; i++) { List<ROCBinary.PrecisionRecallPoint> pExp = rb.getPrecisionRecallCurve(i); List<ROCBinary.PrecisionRecallPoint> p = rbMasked.getPrecisionRecallCurve(i); assertEquals(pExp, p); } } }