/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.eval; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; import org.deeplearning4j.eval.meta.Prediction; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.util.FeatureUtil; import java.util.*; import static org.junit.Assert.*; /** * Created by agibsonccc on 12/22/14. */ public class EvalTest { @Test public void testEval() { int classNum = 5; Evaluation eval = new Evaluation(classNum); // Testing the edge case when some classes do not have true positive INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 5); //[1,0,0,0,0] INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 5); //[1,0,0,0,0] eval.eval(trueOutcome, predictedOutcome); assertEquals(1, eval.classCount(0)); assertEquals(1.0, eval.f1(), 1e-1); // Testing more than one sample. eval() does not reset the Evaluation instance INDArray trueOutcome2 = FeatureUtil.toOutcomeVector(1, 5); //[0,1,0,0,0] INDArray predictedOutcome2 = FeatureUtil.toOutcomeVector(0, 5); //[1,0,0,0,0] eval.eval(trueOutcome2, predictedOutcome2); // Verified with sklearn in Python // from sklearn.metrics import classification_report // classification_report(['a', 'a'], ['a', 'b'], labels=['a', 'b', 'c', 'd', 'e']) assertEquals(eval.f1(), 0.6, 1e-1); // The first entry is 0 label assertEquals(1, eval.classCount(0)); // The first entry is 1 label assertEquals(1, eval.classCount(1)); // Class 0: one positive, one negative -> (one true positive, one false positive); no true/false negatives assertEquals(1, eval.positive().get(0), 0); assertEquals(1, eval.negative().get(0), 0); assertEquals(1, eval.truePositives().get(0), 0); assertEquals(1, eval.falsePositives().get(0), 0); assertEquals(0, eval.trueNegatives().get(0), 0); assertEquals(0, eval.falseNegatives().get(0), 0); // The rest are negative assertEquals(1, eval.negative().get(0), 0); // 2 rows and only the first is correct assertEquals(0.5, eval.accuracy(), 0); } @Test public void testEval2() { //Confusion matrix: //actual 0 20 3 //actual 1 10 5 Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1")); INDArray predicted0 = Nd4j.create(new double[] {1, 0}); INDArray predicted1 = Nd4j.create(new double[] {0, 1}); INDArray actual0 = Nd4j.create(new double[] {1, 0}); INDArray actual1 = Nd4j.create(new double[] {0, 1}); for (int i = 0; i < 20; i++) { evaluation.eval(actual0, predicted0); } for (int i = 0; i < 3; i++) { evaluation.eval(actual0, predicted1); } for (int i = 0; i < 10; i++) { evaluation.eval(actual1, predicted0); } for (int i = 0; i < 5; i++) { evaluation.eval(actual1, predicted1); } assertEquals(20, evaluation.truePositives().get(0), 0); assertEquals(3, evaluation.falseNegatives().get(0), 0); assertEquals(10, evaluation.falsePositives().get(0), 0); assertEquals(5, evaluation.trueNegatives().get(0), 0); assertEquals((20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6); System.out.println(evaluation.confusionToString()); } @Test public void testStringListLabels() { INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); List<String> labelsList = new ArrayList<>(); labelsList.add("hobbs"); labelsList.add("cal"); Evaluation eval = new Evaluation(labelsList); eval.eval(trueOutcome, predictedOutcome); assertEquals(1, eval.classCount(0)); assertEquals(labelsList.get(0), eval.getClassLabel(0)); } @Test public void testStringHashLabels() { INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); Map<Integer, String> labelsMap = new HashMap<>(); labelsMap.put(0, "hobbs"); labelsMap.put(1, "cal"); Evaluation eval = new Evaluation(labelsMap); eval.eval(trueOutcome, predictedOutcome); assertEquals(1, eval.classCount(0)); assertEquals(labelsMap.get(0), eval.getClassLabel(0)); } @Test public void testIris() { // Network config MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).iterations(1).seed(42) .learningRate(1e-6).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER) .activation(Activation.SOFTMAX).build()) .build(); // Instantiate model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1))); // Train-test split DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSet next = iter.next(); next.shuffle(); SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42)); // Train DataSet train = trainTest.getTrain(); train.normalizeZeroMeanZeroUnitVariance(); // Test DataSet test = trainTest.getTest(); test.normalizeZeroMeanZeroUnitVariance(); INDArray testFeature = test.getFeatureMatrix(); INDArray testLabel = test.getLabels(); // Fitting model model.fit(train); // Get predictions from test feature INDArray testPredictedLabel = model.output(testFeature); // Eval with class number Evaluation eval = new Evaluation(3); //// Specify class num here eval.eval(testLabel, testPredictedLabel); double eval1F1 = eval.f1(); double eval1Acc = eval.accuracy(); // Eval without class number Evaluation eval2 = new Evaluation(); //// No class num eval2.eval(testLabel, testPredictedLabel); double eval2F1 = eval2.f1(); double eval2Acc = eval2.accuracy(); //Assert the two implementations give same f1 and accuracy (since one batch) assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc); Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator(Collections.singletonList(test))); checkEvaluationEquality(eval, evalViaMethod); System.out.println(eval.getConfusionMatrix().toString()); System.out.println(eval.getConfusionMatrix().toCSV()); System.out.println(eval.getConfusionMatrix().toHTML()); System.out.println(eval.confusionToString()); } @Test public void testEvalMasking() { int miniBatch = 5; int nOut = 3; int tsLength = 6; INDArray labels = Nd4j.zeros(miniBatch, nOut, tsLength); INDArray predicted = Nd4j.zeros(miniBatch, nOut, tsLength); Nd4j.getRandom().setSeed(12345); Random r = new Random(12345); for (int i = 0; i < miniBatch; i++) { for (int j = 0; j < tsLength; j++) { INDArray rand = Nd4j.rand(1, nOut); rand.divi(rand.sumNumber()); predicted.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)}, rand); int idx = r.nextInt(nOut); labels.putScalar(new int[] {i, idx, j}, 1.0); } } //Create a longer labels/predicted with mask for first and last time step //Expect masked evaluation to be identical to original evaluation INDArray labels2 = Nd4j.zeros(miniBatch, nOut, tsLength + 2); labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, tsLength + 1)}, labels); INDArray predicted2 = Nd4j.zeros(miniBatch, nOut, tsLength + 2); predicted2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, tsLength + 1)}, predicted); INDArray labelsMask = Nd4j.ones(miniBatch, tsLength + 2); for (int i = 0; i < miniBatch; i++) { labelsMask.putScalar(new int[] {i, 0}, 0.0); labelsMask.putScalar(new int[] {i, tsLength + 1}, 0.0); } Evaluation evaluation = new Evaluation(); evaluation.evalTimeSeries(labels, predicted); Evaluation evaluation2 = new Evaluation(); evaluation2.evalTimeSeries(labels2, predicted2, labelsMask); System.out.println(evaluation.stats()); System.out.println(evaluation2.stats()); assertEquals(evaluation.accuracy(), evaluation2.accuracy(), 1e-12); assertEquals(evaluation.f1(), evaluation2.f1(), 1e-12); assertMapEquals(evaluation.falsePositives(), evaluation2.falsePositives()); assertMapEquals(evaluation.falseNegatives(), evaluation2.falseNegatives()); assertMapEquals(evaluation.truePositives(), evaluation2.truePositives()); assertMapEquals(evaluation.trueNegatives(), evaluation2.trueNegatives()); for (int i = 0; i < nOut; i++) assertEquals(evaluation.classCount(i), evaluation2.classCount(i)); } private static void assertMapEquals(Map<Integer, Integer> first, Map<Integer, Integer> second) { assertEquals(first.keySet(), second.keySet()); for (Integer i : first.keySet()) { assertEquals(first.get(i), second.get(i)); } } @Test public void testFalsePerfectRecall() { int testSize = 100; int numClasses = 5; int winner = 1; int seed = 241; INDArray labels = Nd4j.zeros(testSize, numClasses); INDArray predicted = Nd4j.zeros(testSize, numClasses); Nd4j.getRandom().setSeed(seed); Random r = new Random(seed); //Modelling the situation when system predicts the same class every time for (int i = 0; i < testSize; i++) { //Generating random prediction but with a guaranteed winner INDArray rand = Nd4j.rand(1, numClasses); rand.put(0, winner, rand.sumNumber()); rand.divi(rand.sumNumber()); predicted.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all()}, rand); //Generating random label int label = r.nextInt(numClasses); labels.putScalar(new int[] {i, label}, 1.0); } //Explicitly specify the amount of classes Evaluation eval = new Evaluation(numClasses); eval.eval(labels, predicted); //For sure we shouldn't arrive at 100% recall unless we guessed everything right for every class assertNotEquals(1.0, eval.recall()); } @Test public void testEvaluationMerging() { int nRows = 20; int nCols = 3; Random r = new Random(12345); INDArray actual = Nd4j.create(nRows, nCols); INDArray predicted = Nd4j.create(nRows, nCols); for (int i = 0; i < nRows; i++) { int x1 = r.nextInt(nCols); int x2 = r.nextInt(nCols); actual.putScalar(new int[] {i, x1}, 1.0); predicted.putScalar(new int[] {i, x2}, 1.0); } Evaluation evalExpected = new Evaluation(); evalExpected.eval(actual, predicted); //Now: split into 3 separate evaluation objects -> expect identical values after merging Evaluation eval1 = new Evaluation(); eval1.eval(actual.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()), predicted.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all())); Evaluation eval2 = new Evaluation(); eval2.eval(actual.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()), predicted.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all())); Evaluation eval3 = new Evaluation(); eval3.eval(actual.get(NDArrayIndex.interval(10, nRows), NDArrayIndex.all()), predicted.get(NDArrayIndex.interval(10, nRows), NDArrayIndex.all())); eval1.merge(eval2); eval1.merge(eval3); checkEvaluationEquality(evalExpected, eval1); //Next: check evaluation merging with empty, and empty merging with non-empty eval1 = new Evaluation(); eval1.eval(actual.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()), predicted.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all())); Evaluation evalInitiallyEmpty = new Evaluation(); evalInitiallyEmpty.merge(eval1); evalInitiallyEmpty.merge(eval2); evalInitiallyEmpty.merge(eval3); checkEvaluationEquality(evalExpected, evalInitiallyEmpty); eval1.merge(new Evaluation()); eval1.merge(eval2); eval1.merge(new Evaluation()); eval1.merge(eval3); checkEvaluationEquality(evalExpected, eval1); } private static void checkEvaluationEquality(Evaluation evalExpected, Evaluation evalActual) { assertEquals(evalExpected.accuracy(), evalActual.accuracy(), 1e-3); assertEquals(evalExpected.f1(), evalActual.f1(), 1e-3); assertEquals(evalExpected.getNumRowCounter(), evalActual.getNumRowCounter(), 1e-3); assertMapEquals(evalExpected.falseNegatives(), evalActual.falseNegatives()); assertMapEquals(evalExpected.falsePositives(), evalActual.falsePositives()); assertMapEquals(evalExpected.trueNegatives(), evalActual.trueNegatives()); assertMapEquals(evalExpected.truePositives(), evalActual.truePositives()); assertEquals(evalExpected.precision(), evalActual.precision(), 1e-3); assertEquals(evalExpected.recall(), evalActual.recall(), 1e-3); assertEquals(evalExpected.falsePositiveRate(), evalActual.falsePositiveRate(), 1e-3); assertEquals(evalExpected.falseNegativeRate(), evalActual.falseNegativeRate(), 1e-3); assertEquals(evalExpected.falseAlarmRate(), evalActual.falseAlarmRate(), 1e-3); assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix()); } @Test public void testSingleClassBinaryClassification() { Evaluation eval = new Evaluation(1); for (int xe = 0; xe < 3; xe++) { INDArray zero = Nd4j.create(1); INDArray one = Nd4j.ones(1); //One incorrect, three correct eval.eval(one, zero); eval.eval(one, one); eval.eval(one, one); eval.eval(zero, zero); System.out.println(eval.stats()); assertEquals(0.75, eval.accuracy(), 1e-6); assertEquals(4, eval.getNumRowCounter()); assertEquals(1, (int) eval.truePositives().get(0)); assertEquals(2, (int) eval.truePositives().get(1)); assertEquals(1, (int) eval.falseNegatives().get(1)); eval.reset(); } } @Test public void testEvalInvalid() { Evaluation e = new Evaluation(5); e.eval(0, 1); e.eval(1, 0); e.eval(1, 1); System.out.println(e.stats()); char c = "\uFFFD".toCharArray()[0]; System.out.println(c); assertFalse(e.stats().contains("\uFFFD")); } @Test public void testEvalMethods() { //Check eval(int,int) vs. eval(INDArray,INDArray) Evaluation e1 = new Evaluation(4); Evaluation e2 = new Evaluation(4); INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0}); INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0}); INDArray i2 = Nd4j.create(new double[] {0, 0, 1, 0}); INDArray i3 = Nd4j.create(new double[] {0, 0, 0, 1}); e1.eval(i0, i0); //order: actual, predicted e2.eval(0, 0); //order: predicted, actual e1.eval(i0, i2); e2.eval(2, 0); e1.eval(i0, i2); e2.eval(2, 0); e1.eval(i1, i2); e2.eval(2, 1); e1.eval(i3, i3); e2.eval(3, 3); e1.eval(i3, i0); e2.eval(0, 3); e1.eval(i3, i0); e2.eval(0, 3); ConfusionMatrix<Integer> cm = e1.getConfusionMatrix(); assertEquals(1, cm.getCount(0, 0)); //Order: actual, predicted assertEquals(2, cm.getCount(0, 2)); assertEquals(1, cm.getCount(1, 2)); assertEquals(1, cm.getCount(3, 3)); assertEquals(2, cm.getCount(3, 0)); System.out.println(e1.stats()); System.out.println(e2.stats()); assertEquals(e1.stats(), e2.stats()); } @Test public void testTopNAccuracy() { Evaluation e = new Evaluation(null, 3); INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0, 0}); INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0, 0}); INDArray p0_0 = Nd4j.create(new double[] {0.8, 0.05, 0.05, 0.05, 0.05}); //class 0: highest prob INDArray p0_1 = Nd4j.create(new double[] {0.4, 0.45, 0.05, 0.05, 0.05}); //class 0: 2nd highest prob INDArray p0_2 = Nd4j.create(new double[] {0.1, 0.45, 0.35, 0.05, 0.05}); //class 0: 3rd highest prob INDArray p0_3 = Nd4j.create(new double[] {0.1, 0.40, 0.30, 0.15, 0.05}); //class 0: 4th highest prob INDArray p1_0 = Nd4j.create(new double[] {0.05, 0.80, 0.05, 0.05, 0.05}); //class 1: highest prob INDArray p1_1 = Nd4j.create(new double[] {0.45, 0.40, 0.05, 0.05, 0.05}); //class 1: 2nd highest prob INDArray p1_2 = Nd4j.create(new double[] {0.35, 0.10, 0.45, 0.05, 0.05}); //class 1: 3rd highest prob INDArray p1_3 = Nd4j.create(new double[] {0.40, 0.10, 0.30, 0.15, 0.05}); //class 1: 4th highest prob // Correct TopNCorrect Total e.eval(i0, p0_0); // 1 1 1 assertEquals(1.0, e.accuracy(), 1e-6); assertEquals(1.0, e.topNAccuracy(), 1e-6); assertEquals(1, e.getTopNCorrectCount()); assertEquals(1, e.getTopNTotalCount()); e.eval(i0, p0_1); // 1 2 2 assertEquals(0.5, e.accuracy(), 1e-6); assertEquals(1.0, e.topNAccuracy(), 1e-6); assertEquals(2, e.getTopNCorrectCount()); assertEquals(2, e.getTopNTotalCount()); e.eval(i0, p0_2); // 1 3 3 assertEquals(1.0 / 3, e.accuracy(), 1e-6); assertEquals(1.0, e.topNAccuracy(), 1e-6); assertEquals(3, e.getTopNCorrectCount()); assertEquals(3, e.getTopNTotalCount()); e.eval(i0, p0_3); // 1 3 4 assertEquals(0.25, e.accuracy(), 1e-6); assertEquals(0.75, e.topNAccuracy(), 1e-6); assertEquals(3, e.getTopNCorrectCount()); assertEquals(4, e.getTopNTotalCount()); e.eval(i1, p1_0); // 2 4 5 assertEquals(2.0 / 5, e.accuracy(), 1e-6); assertEquals(4.0 / 5, e.topNAccuracy(), 1e-6); e.eval(i1, p1_1); // 2 5 6 assertEquals(2.0 / 6, e.accuracy(), 1e-6); assertEquals(5.0 / 6, e.topNAccuracy(), 1e-6); e.eval(i1, p1_2); // 2 6 7 assertEquals(2.0 / 7, e.accuracy(), 1e-6); assertEquals(6.0 / 7, e.topNAccuracy(), 1e-6); e.eval(i1, p1_3); // 2 6 8 assertEquals(2.0 / 8, e.accuracy(), 1e-6); assertEquals(6.0 / 8, e.topNAccuracy(), 1e-6); assertEquals(6, e.getTopNCorrectCount()); assertEquals(8, e.getTopNTotalCount()); System.out.println(e.stats()); } @Test public void testTopNAccuracyMerging() { Evaluation e1 = new Evaluation(null, 3); Evaluation e2 = new Evaluation(null, 3); INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0, 0}); INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0, 0}); INDArray p0_0 = Nd4j.create(new double[] {0.8, 0.05, 0.05, 0.05, 0.05}); //class 0: highest prob INDArray p0_1 = Nd4j.create(new double[] {0.4, 0.45, 0.05, 0.05, 0.05}); //class 0: 2nd highest prob INDArray p0_2 = Nd4j.create(new double[] {0.1, 0.45, 0.35, 0.05, 0.05}); //class 0: 3rd highest prob INDArray p0_3 = Nd4j.create(new double[] {0.1, 0.40, 0.30, 0.15, 0.05}); //class 0: 4th highest prob INDArray p1_0 = Nd4j.create(new double[] {0.05, 0.80, 0.05, 0.05, 0.05}); //class 1: highest prob INDArray p1_1 = Nd4j.create(new double[] {0.45, 0.40, 0.05, 0.05, 0.05}); //class 1: 2nd highest prob INDArray p1_2 = Nd4j.create(new double[] {0.35, 0.10, 0.45, 0.05, 0.05}); //class 1: 3rd highest prob INDArray p1_3 = Nd4j.create(new double[] {0.40, 0.10, 0.30, 0.15, 0.05}); //class 1: 4th highest prob // Correct TopNCorrect Total e1.eval(i0, p0_0); // 1 1 1 e1.eval(i0, p0_1); // 1 2 2 e1.eval(i0, p0_2); // 1 3 3 e1.eval(i0, p0_3); // 1 3 4 assertEquals(0.25, e1.accuracy(), 1e-6); assertEquals(0.75, e1.topNAccuracy(), 1e-6); assertEquals(3, e1.getTopNCorrectCount()); assertEquals(4, e1.getTopNTotalCount()); e2.eval(i1, p1_0); // 1 1 1 e2.eval(i1, p1_1); // 1 2 2 e2.eval(i1, p1_2); // 1 3 3 e2.eval(i1, p1_3); // 1 3 4 assertEquals(1.0 / 4, e2.accuracy(), 1e-6); assertEquals(3.0 / 4, e2.topNAccuracy(), 1e-6); assertEquals(3, e2.getTopNCorrectCount()); assertEquals(4, e2.getTopNTotalCount()); e1.merge(e2); assertEquals(8, e1.getNumRowCounter()); assertEquals(8, e1.getTopNTotalCount()); assertEquals(6, e1.getTopNCorrectCount()); assertEquals(2.0 / 8, e1.accuracy(), 1e-6); assertEquals(6.0 / 8, e1.topNAccuracy(), 1e-6); } @Test public void testEvaluationWithMetaData() throws Exception { RecordReader csv = new CSVRecordReader(); csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive())); int batchSize = 10; int labelIdx = 4; int numClasses = 3; RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses); NormalizerStandardize ns = new NormalizerStandardize(); ns.fit(rrdsi); rrdsi.setPreProcessor(ns); rrdsi.reset(); Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).iterations(1) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD) .learningRate(0.1).list() .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(4).nOut(3).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); for (int i = 0; i < 4; i++) { net.fit(rrdsi); rrdsi.reset(); } Evaluation e = new Evaluation(); rrdsi.setCollectMetaData(true); //*** New: Enable collection of metadata (stored in the DataSets) *** while (rrdsi.hasNext()) { DataSet ds = rrdsi.next(); List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class); //*** New - cross dependencies here make types difficult, usid Object internally in DataSet for this*** INDArray out = net.output(ds.getFeatures()); e.eval(ds.getLabels(), out, meta); //*** New - evaluate and also store metadata *** } System.out.println(e.stats()); System.out.println("\n\n*** Prediction Errors: ***"); List<Prediction> errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation *** List<RecordMetaData> metaForErrors = new ArrayList<>(); for (Prediction p : errors) { metaForErrors.add((RecordMetaData) p.getRecordMetaData()); } DataSet ds = rrdsi.loadFromMetaData(metaForErrors); //*** New - dynamically load a subset of the data, just for prediction errors *** INDArray output = net.output(ds.getFeatures()); int count = 0; for (Prediction t : errors) { System.out.println(t + "\t\tRaw Data: " + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) *** + "\tNormalized: " + ds.getFeatureMatrix().getRow(count) + "\tLabels: " + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count)); count++; } int errorCount = errors.size(); double expAcc = 1.0 - errorCount / 150.0; assertEquals(expAcc, e.accuracy(), 1e-5); ConfusionMatrix<Integer> confusion = e.getConfusionMatrix(); int[] actualCounts = new int[3]; int[] predictedCounts = new int[3]; for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { int entry = confusion.getCount(i, j); //(actual,predicted) List<Prediction> list = e.getPredictions(i, j); assertEquals(entry, list.size()); actualCounts[i] += entry; predictedCounts[j] += entry; } } for (int i = 0; i < 3; i++) { List<Prediction> actualClassI = e.getPredictionsByActualClass(i); List<Prediction> predictedClassI = e.getPredictionByPredictedClass(i); assertEquals(actualCounts[i], actualClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size()); } } @Test public void testBinaryCase() { INDArray ones10 = Nd4j.ones(10, 1); INDArray ones4 = Nd4j.ones(4, 1); INDArray zeros4 = Nd4j.zeros(4, 1); INDArray ones3 = Nd4j.ones(3, 1); INDArray zeros3 = Nd4j.zeros(3, 1); INDArray zeros2 = Nd4j.zeros(2, 1); Evaluation e = new Evaluation(); e.eval(ones10, ones10); //10 true positives e.eval(ones3, zeros3); //3 false negatives e.eval(zeros4, ones4); //4 false positives e.eval(zeros2, zeros2); //2 true negatives assertEquals((10 + 2) / (double) (10 + 3 + 4 + 2), e.accuracy(), 1e-6); assertEquals(10, (int) e.truePositives().get(1)); assertEquals(3, (int) e.falseNegatives().get(1)); assertEquals(4, (int) e.falsePositives().get(1)); assertEquals(2, (int) e.trueNegatives().get(1)); //If we switch the label around: tp becomes tn, fp becomes fn, etc assertEquals(10, (int) e.trueNegatives().get(0)); assertEquals(3, (int) e.falsePositives().get(0)); assertEquals(4, (int) e.falseNegatives().get(0)); assertEquals(2, (int) e.truePositives().get(0)); } @Test public void testF1FBeta_MicroMacroAveraging(){ //Confusion matrix: rows = actual, columns = predicted //[3, 1, 0] //[2, 2, 1] //[0, 3, 4] INDArray zero = Nd4j.create(new double[]{1,0,0}); INDArray one = Nd4j.create(new double[]{0,1,0}); INDArray two = Nd4j.create(new double[]{0,0,1}); Evaluation e = new Evaluation(); apply(e, 3, zero, zero); apply(e, 1, one, zero); apply(e, 2, zero, one); apply(e, 2, one, one); apply(e, 1, two, one); apply(e, 3, one, two); apply(e, 4, two, two); assertEquals(3, e.confusion.getCount(0,0)); assertEquals(1, e.confusion.getCount(0,1)); assertEquals(0, e.confusion.getCount(0,2)); assertEquals(2, e.confusion.getCount(1,0)); assertEquals(2, e.confusion.getCount(1,1)); assertEquals(1, e.confusion.getCount(1,2)); assertEquals(0, e.confusion.getCount(2,0)); assertEquals(3, e.confusion.getCount(2,1)); assertEquals(4, e.confusion.getCount(2,2)); double beta = 3.5; double[] prec = new double[3]; double[] rec = new double[3]; for( int i=0; i<3; i++ ){ prec[i] = e.truePositives().get(i) / (double)(e.truePositives().get(i) + e.falsePositives().get(i)); rec[i] = e.truePositives().get(i) / (double)(e.truePositives().get(i) + e.falseNegatives().get(i)); } //Binarized confusion //class 0: // [3, 1] [tp fn] // [2, 10] [fp tn] assertEquals(3, (int)e.truePositives().get(0)); assertEquals(1, (int)e.falseNegatives().get(0)); assertEquals(2, (int)e.falsePositives().get(0)); assertEquals(10, (int)e.trueNegatives().get(0)); //class 1: // [2, 3] [tp fn] // [4, 7] [fp tn] assertEquals(2, (int)e.truePositives().get(1)); assertEquals(3, (int)e.falseNegatives().get(1)); assertEquals(4, (int)e.falsePositives().get(1)); assertEquals(7, (int)e.trueNegatives().get(1)); //class 2: // [4, 3] [tp fn] // [1, 8] [fp tn] assertEquals(4, (int)e.truePositives().get(2)); assertEquals(3, (int)e.falseNegatives().get(2)); assertEquals(1, (int)e.falsePositives().get(2)); assertEquals(8, (int)e.trueNegatives().get(2)); double[] fBeta = new double[3]; double[] f1 = new double[3]; double[] mcc = new double[3]; for( int i=0; i<3; i++ ){ fBeta[i] = (1+beta*beta)*prec[i]*rec[i] / (beta*beta*prec[i] + rec[i]); f1[i] = 2*prec[i]*rec[i] / (prec[i] + rec[i]); assertEquals(fBeta[i], e.fBeta(beta, i), 1e-6); assertEquals(f1[i], e.f1(i), 1e-6); double gmeasure = Math.sqrt(prec[i] * rec[i]); assertEquals(gmeasure, e.gMeasure(i), 1e-6); double tp = e.truePositives().get(i); double tn = e.trueNegatives().get(i); double fp = e.falsePositives().get(i); double fn = e.falseNegatives().get(i); mcc[i] = (tp*tn - fp*fn) / Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)); assertEquals(mcc[i], e.matthewsCorrelation(i), 1e-6); } //Test macro and micro averaging: int tp = 0; int fn = 0; int fp = 0; int tn = 0; double macroPrecision = 0.0; double macroRecall = 0.0; double macroF1 = 0.0; double macroFBeta = 0.0; double macroMcc = 0.0; for( int i=0; i<3; i++ ){ tp += e.truePositives().get(i); fn += e.falseNegatives().get(i); fp += e.falsePositives().get(i); tn += e.trueNegatives().get(i); macroPrecision += prec[i]; macroRecall += rec[i]; macroF1 += f1[i]; macroFBeta += fBeta[i]; macroMcc += mcc[i]; } macroPrecision /= 3; macroRecall /= 3; macroF1 /= 3; macroFBeta /= 3; macroMcc /= 3; double microPrecision = tp / (double)(tp + fp); double microRecall = tp / (double)(tp + fn); double microFBeta = (1+beta*beta)*microPrecision*microRecall / (beta*beta*microPrecision + microRecall); double microF1 = 2*microPrecision*microRecall / (microPrecision + microRecall); double microMcc = (tp*tn - fp*fn) / Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)); assertEquals(microPrecision, e.precision(EvaluationAveraging.Micro), 1e-6); assertEquals(microRecall, e.recall(EvaluationAveraging.Micro), 1e-6); assertEquals(macroPrecision, e.precision(EvaluationAveraging.Macro), 1e-6); assertEquals(macroRecall, e.recall(EvaluationAveraging.Macro), 1e-6); assertEquals(microFBeta, e.fBeta(beta, EvaluationAveraging.Micro), 1e-6); assertEquals(macroFBeta, e.fBeta(beta, EvaluationAveraging.Macro), 1e-6); assertEquals(microF1, e.f1(EvaluationAveraging.Micro), 1e-6); assertEquals(macroF1, e.f1(EvaluationAveraging.Macro), 1e-6); assertEquals(microMcc, e.matthewsCorrelation(EvaluationAveraging.Micro), 1e-6); assertEquals(macroMcc, e.matthewsCorrelation(EvaluationAveraging.Macro), 1e-6); } private static void apply(Evaluation e, int nTimes, INDArray predicted, INDArray actual){ for( int i=0; i<nTimes; i++ ){ e.eval(actual, predicted); } } }