/* * Carrot2 project. * * Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński. * All rights reserved. * * Refer to the full license file "carrot2.LICENSE" * in the root folder of the repository checkout or at: * http://www.carrot2.org/carrot2.LICENSE */ package org.carrot2.output.metrics; import org.carrot2.core.Cluster; import org.carrot2.util.MathUtils; import org.fest.assertions.Delta; import org.junit.Test; import org.carrot2.shaded.guava.common.collect.Lists; /** * Test cases for {@link IClusteringMetric}. */ public class PrecisionRecallMetricTest extends IdealPartitioningBasedMetricTest { @Test public void testEmptyCluster() { check(null, null, null, new Cluster()); } @Test public void testTrivialCluster() { check(1.0, 1.0, 1.0, new Cluster("test", documentWithPartitions("test"))); } @Test public void testPartiallyContaminatedCluster() { check((3 * 0.75 + 1 * 0.25) / 4, 1.0, (3 * MathUtils.harmonicMean(0.75, 1.0) + 1 * MathUtils .harmonicMean(0.25, 1.0)) / 4, partiallyContaminatedCluster()); } @Test public void testFullyContaminatedCluster() { check(0.25, 1.0, 2 * 0.25 / (1 + 0.25), fullyContaminatedCluster()); } @Test public void testPureCluster() { check(1.0, 1.0, 1.0, pureCluster()); } @Test public void testHardClustersWithOverlappingPartitions() { check(1.0, MathUtils.arithmeticMean(2.0 / 3.0, 1, 3, 2), MathUtils .arithmeticMean(MathUtils.harmonicMean(2.0 / 3.0, 1), 1, 3, 2), hardClustersWithOverlappingPartitions()); } @Test public void testHardPartitionsOverlappingClusters() { check(MathUtils.arithmeticMean(2.0 / 3.0, 1, 2, 2), 1.0, MathUtils .arithmeticMean(MathUtils.harmonicMean(2.0 / 3.0, 1), 1, 2, 2), overlappingClustersWithHardPartitions()); } @Test public void testOverlappingPartitionsOverlappingClusters() { check(1.0, 1.0, 1.0, overlappingClustersWithOverlappingPartitions()); } @Test public void testAllDocumentsInOtherTopics() { final Cluster otherTopics = clusterWithPartitions("t1", "t2", "t3"); otherTopics.setOtherTopics(true); check(0.0, 0.0, 0.0, otherTopics); } @Test public void testIdealClustering() { check(1.0, 1.0, 1.0, idealClusters()); } private void check(Double expectedAveragePrecision, Double expectedAverageRecall, Double expectedAverageFMeasure, Cluster... clusters) { final PrecisionRecallMetric metric = new PrecisionRecallMetric(); metric.documents = getAllDocuments(clusters); metric.clusters = Lists.newArrayList(clusters); metric.calculate(); assertEquals(expectedAveragePrecision, metric.weightedAveragePrecision, 0.001, "precision"); assertEquals(expectedAverageRecall, metric.weightedAverageRecall, 0.001, "recall"); assertEquals(expectedAverageFMeasure, metric.weightedAverageFMeasure, 0.001, "f-measure"); } private static void assertEquals(Double expected, Double actual, double delta, String as) { if (expected != null) { assertThat(actual).as(as).isEqualTo(expected, Delta.delta(delta)); } else { assertThat((Object) actual).as(as).isEqualTo(expected); } } @Override protected String [] getClusterMetricKeys() { return new String [] { PrecisionRecallMetric.BEST_F_MEASURE_PARTITION }; } }