package org.deeplearning4j.eval; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; /** * Created by Alex on 20/03/2017. */ public class EvaluationBinaryTest { @Test public void testEvaluationBinary() { //Compare EvaluationBinary to Evaluation class Nd4j.getRandom().setSeed(12345); int nExamples = 50; int nOut = 4; int[] shape = {nExamples, nOut}; INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5)); INDArray predicted = Nd4j.rand(shape); INDArray binaryPredicted = predicted.gt(0.5); EvaluationBinary eb = new EvaluationBinary(); eb.eval(labels, predicted); System.out.println(eb.stats()); double eps = 1e-6; for (int i = 0; i < nOut; i++) { INDArray lCol = labels.getColumn(i); INDArray pCol = predicted.getColumn(i); INDArray bpCol = binaryPredicted.getColumn(i); int countCorrect = 0; int tpCount = 0; int tnCount = 0; for (int j = 0; j < lCol.length(); j++) { if (lCol.getDouble(j) == bpCol.getDouble(j)) { countCorrect++; if (lCol.getDouble(j) == 1) { tpCount++; } else { tnCount++; } } } double acc = countCorrect / (double) lCol.length(); Evaluation e = new Evaluation(); e.eval(lCol, pCol); assertEquals(acc, eb.accuracy(i), eps); assertEquals(e.accuracy(), eb.accuracy(i), eps); assertEquals(e.precision(1), eb.precision(i), eps); assertEquals(e.recall(1), eb.recall(i), eps); assertEquals(e.f1(1), eb.f1(i), eps); assertEquals(tpCount, eb.truePositives(i)); assertEquals(tnCount, eb.trueNegatives(i)); assertEquals((int) e.truePositives().get(1), eb.truePositives(i)); assertEquals((int) e.trueNegatives().get(1), eb.trueNegatives(i)); assertEquals((int) e.falsePositives().get(1), eb.falsePositives(i)); assertEquals((int) e.falseNegatives().get(1), eb.falseNegatives(i)); assertEquals(nExamples, eb.totalCount(i)); } } @Test public void testEvaluationBinaryMerging() { int nOut = 4; int[] shape1 = {30, nOut}; int[] shape2 = {50, nOut}; Nd4j.getRandom().setSeed(12345); INDArray l1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape1), 0.5)); INDArray l2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape2), 0.5)); INDArray p1 = Nd4j.rand(shape1); INDArray p2 = Nd4j.rand(shape2); EvaluationBinary eb = new EvaluationBinary(); eb.eval(l1, p1); eb.eval(l2, p2); EvaluationBinary eb1 = new EvaluationBinary(); eb1.eval(l1, p1); EvaluationBinary eb2 = new EvaluationBinary(); eb2.eval(l2, p2); eb1.merge(eb2); assertEquals(eb.stats(), eb1.stats()); } @Test public void testEvaluationBinaryPerOutputMasking() { //Provide a mask array: "ignore" the masked steps INDArray mask = Nd4j.create(new double[][] {{1, 1, 0}, {1, 0, 0}, {1, 1, 0}, {1, 0, 0}, {1, 1, 1}}); INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}}); INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6}, {0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}}); //Correct? // Y Y m // N m m // Y Y m // Y m m // N Y N EvaluationBinary eb = new EvaluationBinary(); eb.eval(labels, predicted, mask); assertEquals(0.6, eb.accuracy(0), 1e-6); assertEquals(1.0, eb.accuracy(1), 1e-6); assertEquals(0.0, eb.accuracy(2), 1e-6); assertEquals(2, eb.truePositives(0)); assertEquals(2, eb.truePositives(1)); assertEquals(0, eb.truePositives(2)); assertEquals(1, eb.trueNegatives(0)); assertEquals(1, eb.trueNegatives(1)); assertEquals(0, eb.trueNegatives(2)); assertEquals(1, eb.falsePositives(0)); assertEquals(0, eb.falsePositives(1)); assertEquals(0, eb.falsePositives(2)); assertEquals(1, eb.falseNegatives(0)); assertEquals(0, eb.falseNegatives(1)); assertEquals(1, eb.falseNegatives(2)); } @Test public void testTimeSeriesEval() { int[] shape = {2, 4, 3}; Nd4j.getRandom().setSeed(12345); INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5)); INDArray predicted = Nd4j.rand(shape); INDArray mask = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5)); EvaluationBinary eb1 = new EvaluationBinary(); eb1.eval(labels, predicted, mask); EvaluationBinary eb2 = new EvaluationBinary(); for (int i = 0; i < shape[2]; i++) { INDArray l = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i)); INDArray p = predicted.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i)); INDArray m = mask.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i)); eb2.eval(l, p, m); } assertEquals(eb2.stats(), eb1.stats()); } @Test public void testEvaluationBinaryWithROC() { //Simple test for nested ROCBinary in EvaluationBinary Nd4j.getRandom().setSeed(12345); INDArray l1 = Nd4j.getExecutioner() .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5)); INDArray p1 = Nd4j.rand(50, 4); EvaluationBinary eb = new EvaluationBinary(4, 30); eb.eval(l1, p1); System.out.println(eb.stats()); } }