/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier; import java.util.Arrays; import java.util.Collection; import java.util.Map; import com.google.common.collect.Lists; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.math.Matrix; import org.junit.Test; public final class ConfusionMatrixTest extends MahoutTestCase { private static final int[][] VALUES = {{2, 3}, {10, 20}}; private static final String[] LABELS = {"Label1", "Label2"}; private static final int[] OTHER = {3, 6}; private static final String DEFAULT_LABEL = "other"; @Test public void testBuild() { ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL); checkValues(confusionMatrix); checkAccuracy(confusionMatrix); } @Test public void testGetMatrix() { ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL); Matrix m = confusionMatrix.getMatrix(); Map<String, Integer> rowLabels = m.getRowLabelBindings(); assertEquals(confusionMatrix.getLabels().size(), m.numCols()); assertTrue(rowLabels.keySet().contains(LABELS[0])); assertTrue(rowLabels.keySet().contains(LABELS[1])); assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL)); assertEquals(2, confusionMatrix.getCorrect(LABELS[0])); assertEquals(20, confusionMatrix.getCorrect(LABELS[1])); assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL)); } /** * Example taken from * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html */ @Test public void testPrecisionRecallAndF1ScoreAsScikitLearn() { Collection<String> labelList = Arrays.asList("0", "1", "2"); ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT"); confusionMatrix.putCount("0", "0", 2); confusionMatrix.putCount("1", "0", 1); confusionMatrix.putCount("1", "2", 1); confusionMatrix.putCount("2", "1", 2); double delta = 0.001; assertEquals(0.222, confusionMatrix.getWeightedPrecision(), delta); assertEquals(0.333, confusionMatrix.getWeightedRecall(), delta); assertEquals(0.266, confusionMatrix.getWeightedF1score(), delta); } private static void checkValues(ConfusionMatrix cm) { int[][] counts = cm.getConfusionMatrix(); cm.toString(); assertEquals(counts.length, counts[0].length); assertEquals(3, counts.length); assertEquals(VALUES[0][0], counts[0][0]); assertEquals(VALUES[0][1], counts[0][1]); assertEquals(VALUES[1][0], counts[1][0]); assertEquals(VALUES[1][1], counts[1][1]); assertTrue(Arrays.equals(new int[3], counts[2])); // zeros assertEquals(OTHER[0], counts[0][2]); assertEquals(OTHER[1], counts[1][2]); assertEquals(3, cm.getLabels().size()); assertTrue(cm.getLabels().contains(LABELS[0])); assertTrue(cm.getLabels().contains(LABELS[1])); assertTrue(cm.getLabels().contains(DEFAULT_LABEL)); } private static void checkAccuracy(ConfusionMatrix cm) { Collection<String> labelstrs = cm.getLabels(); assertEquals(3, labelstrs.size()); assertEquals(25.0, cm.getAccuracy("Label1"), EPSILON); assertEquals(55.5555555, cm.getAccuracy("Label2"), EPSILON); assertTrue(Double.isNaN(cm.getAccuracy("other"))); } private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) { Collection<String> labelList = Lists.newArrayList(); labelList.add(labels[0]); labelList.add(labels[1]); ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel); confusionMatrix.putCount("Label1", "Label1", values[0][0]); confusionMatrix.putCount("Label1", "Label2", values[0][1]); confusionMatrix.putCount("Label2", "Label1", values[1][0]); confusionMatrix.putCount("Label2", "Label2", values[1][1]); confusionMatrix.putCount("Label1", DEFAULT_LABEL, OTHER[0]); confusionMatrix.putCount("Label2", DEFAULT_LABEL, OTHER[1]); return confusionMatrix; } }