/*
 * Carrot2 project.
 *
 * Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński.
 * All rights reserved.
 *
 * Refer to the full license file "carrot2.LICENSE"
 * in the root folder of the repository checkout or at:
 * http://www.carrot2.org/carrot2.LICENSE
 */

package org.carrot2.output.metrics;

import java.util.*;

import org.carrot2.core.Cluster;
import org.carrot2.core.Document;
import org.carrot2.core.attribute.AttributeNames;
import org.carrot2.core.attribute.Processing;
import org.carrot2.util.attribute.*;

import org.carrot2.shaded.guava.common.collect.Lists;

/**
 * Computes cluster contamination. If a cluster groups documents found in the same
 * {@link Document#PARTITIONS}, its contamination is 0. If a cluster groups an equally
 * distributed mix of all partitions, its contamination is 1.0. For a full definition,
 * please see section 4.4.1 of <a
 * href="http://project.carrot2.org/publications/osinski04-dimensionality.pdf">this
 * work</a>.
 * <p>
 * Contamination is calculated for top-level clusters only, taking into account documents
 * from the cluster and all subclusters. Finally, contamination will be calculated only if
 * all input documents have non-blank {@link Document#PARTITIONS}s.
 * </p>
 */
@Bindable
public class ContaminationMetric extends IdealPartitioningBasedMetric
{
    /**
     * Key for the contamination value of a cluster.
     */
    public static final String CONTAMINATION = "contamination";

    /**
     * Average contamination of the whole cluster set, weighted by the size of cluster.
     */
    @Processing
    @Output
    @Attribute
    public double weightedAverageContamination;

    /**
     * Calculate contamination metric.
     */
    @Processing
    @Input
    @Attribute
    public boolean enabled = true;

    /** Documents to compute the metric for, as delivered by processing. */
    @Processing
    @Input
    @Attribute(key = AttributeNames.DOCUMENTS)
    public List<Document> documents;

    /** Clusters to compute the metric for, as delivered by processing. */
    @Processing
    @Input
    @Attribute(key = AttributeNames.CLUSTERS)
    public List<Cluster> clusters;

    /**
     * Calculates per-cluster contamination (stored under {@link #CONTAMINATION} on each
     * cluster) and the size-weighted average over all non-"Other Topics" clusters
     * (stored in {@link #weightedAverageContamination}). Does nothing when the input
     * documents define no partitions.
     */
    public void calculate()
    {
        final int partitionCount = getPartitionsCount(documents);
        if (partitionCount == 0)
        {
            return;
        }

        int weightSum = 0;
        double contaminationSum = 0;
        for (Cluster cluster : clusters)
        {
            // "Other Topics" is a junk cluster, skip it entirely.
            if (cluster.isOtherTopics())
            {
                continue;
            }
            final double contamination = calculate(cluster, partitionCount);
            cluster.setAttribute(CONTAMINATION, contamination);

            contaminationSum += contamination * cluster.size();
            weightSum += cluster.size();
        }

        // Guard against division by zero: when there are no clusters (or only the
        // "Other Topics" cluster), the previous unconditional division yielded NaN.
        weightedAverageContamination = weightSum == 0 ? 0 : contaminationSum
            / weightSum;
    }

    /**
     * Computes the contamination of a single cluster: the ratio of the cluster's actual
     * partition-pair count H to the worst-case H of an equally mixed cluster of the same
     * size. Returns 0 when the worst case is 0 (degenerate cluster).
     */
    @SuppressWarnings("unchecked")
    double calculate(Cluster cluster, int partitionCount)
    {
        int clusterPartitionAssignments = 0;
        for (Document document : cluster.getAllDocuments())
        {
            clusterPartitionAssignments += ((Collection<Object>) document
                .getField(Document.PARTITIONS)).size();
        }

        final double worstCaseH = calculateWorstCaseH(clusterPartitionAssignments,
            partitionCount);
        if (worstCaseH == 0)
        {
            return 0;
        }
        else
        {
            return calculateH(cluster) / worstCaseH;
        }
    }

    /**
     * Computes H for a cluster: the pairwise product sum over per-partition document
     * counts of the cluster's documents (including subclusters).
     */
    int calculateH(Cluster cluster)
    {
        final Map<Object, Integer> documentCountByPartition = getDocumentCountByPartition(cluster
            .getAllDocuments());

        final ArrayList<Integer> counts = Lists.newArrayList();
        counts.addAll(documentCountByPartition.values());

        return calculateH(counts);
    }

    /**
     * Computes the worst-case (maximum) H for a cluster of the given size: documents
     * spread as evenly as possible across all partitions, with the remainder
     * distributed one-per-partition among the first partitions.
     */
    static int calculateWorstCaseH(int clusterSize, int partitionCount)
    {
        final ArrayList<Integer> counts = Lists.newArrayList();
        for (int partition = 0; partition < partitionCount; partition++)
        {
            counts.add(clusterSize / partitionCount
                + (partition < (clusterSize % partitionCount) ? 1 : 0));
        }

        return calculateH(counts);
    }

    /**
     * Computes the sum of products over all unordered pairs of counts:
     * sum_{i&lt;j} counts[i] * counts[j].
     */
    static int calculateH(final ArrayList<Integer> counts)
    {
        int h = 0;
        for (int i = 0; i < counts.size() - 1; i++)
        {
            for (int j = i + 1; j < counts.size(); j++)
            {
                h += counts.get(i) * counts.get(j);
            }
        }
        return h;
    }

    /** Returns whether this metric is enabled for processing. */
    public boolean isEnabled()
    {
        return enabled;
    }
}