/**
* This file is part of General Entity Annotator Benchmark.
*
* General Entity Annotator Benchmark is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* General Entity Annotator Benchmark is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with General Entity Annotator Benchmark. If not, see <http://www.gnu.org/licenses/>.
*/
package org.aksw.gerbil.matching.impl;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.aksw.gerbil.matching.EvaluationCounts;
import org.aksw.gerbil.matching.MatchingsSearcher;
import org.aksw.gerbil.semantic.kb.UriKBClassifier;
import org.aksw.gerbil.semantic.subclass.ClassNode;
import org.aksw.gerbil.semantic.subclass.ClassSet;
import org.aksw.gerbil.semantic.subclass.ClassifiedClassNode;
import org.aksw.gerbil.semantic.subclass.ClassifyingClassNodeFactory;
import org.aksw.gerbil.semantic.subclass.SimpleClassSet;
import org.aksw.gerbil.semantic.subclass.SubClassInferencer;
import org.aksw.gerbil.transfer.nif.TypedMarking;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.carrotsearch.hppc.BitSet;
/**
 * Compares the types of gold standard elements with the types of the
 * annotator results they were matched to, by expanding both sides into their
 * classes and sub classes and counting the shared and missing classes.
 *
 * @param <T> the typed marking compared by this counter
 */
public class HierarchicalMatchingsCounter<T extends TypedMarking> {

    private static final Logger LOGGER = LoggerFactory.getLogger(HierarchicalMatchingsCounter.class);

    /** Class id attached to class nodes derived from the gold standard's types. */
    private static final int EXPECTED_CLASSES_CLASS_ID = 0;
    /** Class id attached to class nodes derived from the annotator's types. */
    private static final int ANNOTATOR_CLASSES_CLASS_ID = 1;

    /**
     * This matchings counter needs a {@link MatchingsSearcher} that can create
     * pairs of named entities for which the types should be matched to each
     * other.
     */
    protected MatchingsSearcher<T> matchingsSearcher;
    // NOTE(review): this field is never read or written by this class after
    // initialization; kept since subclasses may rely on it.
    protected List<List<int[]>> counts = new ArrayList<List<int[]>>();
    /** Used to expand a type URI into its classes and sub classes. */
    protected SubClassInferencer inferencer;
    /** Decides whether the URIs of a class node belong to a known KB. */
    private final UriKBClassifier uriKBClassifier;

    /**
     * Creates a new counter.
     *
     * @param matchingsSearcher
     *            pairs gold standard elements with annotator results
     * @param uriKBClassifier
     *            checks whether a class node's URIs belong to a known KB; only
     *            such classes are counted
     * @param inferencer
     *            expands each type URI into its classes and sub classes
     */
    public HierarchicalMatchingsCounter(MatchingsSearcher<T> matchingsSearcher, UriKBClassifier uriKBClassifier,
            SubClassInferencer inferencer) {
        this.inferencer = inferencer;
        this.uriKBClassifier = uriKBClassifier;
        this.matchingsSearcher = matchingsSearcher;
    }

    /**
     * Counts the typing matches of the given annotator result against the gold
     * standard.
     * <p>
     * The returned list first holds one entry per gold standard element, in
     * gold standard order. If the element was matched to an annotator result,
     * the entry holds the class-level counts. If both sides only had types
     * outside of known KBs, the entry is a single true positive. If the element
     * was not matched, the entry is one false negative plus one false positive.
     * After these, one false-positive entry is appended for every annotator
     * result that was not used by any match, in annotator result order.
     *
     * @param annotatorResult
     *            the typed markings returned by the annotator
     * @param goldStandard
     *            the expected typed markings
     * @return the per-element evaluation counts
     */
    public List<EvaluationCounts> countMatchings(List<T> annotatorResult, List<T> goldStandard) {
        final List<EvaluationCounts> perElementCounts = new ArrayList<EvaluationCounts>();
        final BitSet usedResults = new BitSet(annotatorResult.size());
        final ClassifyingClassNodeFactory expectedFactory = new ClassifyingClassNodeFactory(
                EXPECTED_CLASSES_CLASS_ID);
        final ClassifyingClassNodeFactory annotatorFactory = new ClassifyingClassNodeFactory(
                ANNOTATOR_CLASSES_CLASS_ID);

        for (T expected : goldStandard) {
            BitSet candidates = matchingsSearcher.findMatchings(expected, annotatorResult, usedResults);
            EvaluationCounts elementCounts;
            if (candidates.isEmpty()) {
                // NOTE(review): annotator results that stay unmatched are also
                // counted as false positives after this loop, so this extra
                // false positive may double-count; confirm it is intended.
                elementCounts = new EvaluationCounts();
                elementCounts.falseNegatives = 1;
                elementCounts.falsePositives = 1;
            } else {
                // The first candidate is taken as the solution for the typing
                // task and cannot be matched again.
                int matchedId = candidates.nextSetBit(0);
                T matched = annotatorResult.get(matchedId);
                usedResults.set(matchedId);

                // Both sides share one class set; the class ids tell them apart.
                ClassSet classes = new SimpleClassSet();
                addInferredClasses(expected.getTypes(), classes, expectedFactory);
                addInferredClasses(matched.getTypes(), classes, annotatorFactory);

                elementCounts = countClassMatchings(classes);
                LOGGER.debug("Type matching found {} (classes={}).", elementCounts, classes);

                // Nothing was counted when neither side had a type of a known
                // KB. That agreement is rewarded as one true positive.
                boolean nothingCounted = elementCounts.truePositives == 0 && elementCounts.falseNegatives == 0
                        && elementCounts.falsePositives == 0;
                if (nothingCounted) {
                    elementCounts.truePositives = 1;
                    LOGGER.info("Got an entity with a type that is not inside a known KB in the annotator and in the dataset.");
                }
            }
            perElementCounts.add(elementCounts);
        }

        // Every annotator result that was never matched is a false positive.
        for (int i = 0; i < annotatorResult.size(); i++) {
            if (usedResults.get(i)) {
                continue;
            }
            LOGGER.debug("found a false positive. {}", annotatorResult.get(i));
            perElementCounts.add(new EvaluationCounts(0, 1, 0));
        }
        return perElementCounts;
    }

    /**
     * Adds the classes and sub classes inferred for each of the given type URIs
     * to {@code classes}, tagging the created nodes via {@code factory}.
     */
    private void addInferredClasses(Set<String> typeURIs, ClassSet classes, ClassifyingClassNodeFactory factory) {
        for (String typeURI : typeURIs) {
            inferencer.inferSubClasses(typeURI, classes, factory);
        }
    }

    /**
     * Counts the classes of a known KB in the given set.
     * <p>
     * A class tagged by both sides is a true positive. A class tagged only as
     * expected is a false negative. A class tagged only by the annotator is a
     * false positive.
     */
    private EvaluationCounts countClassMatchings(ClassSet classes) {
        EvaluationCounts classCounts = new EvaluationCounts();
        for (Iterator<ClassNode> nodes = classes.iterator(); nodes.hasNext();) {
            // Every node was created by a ClassifyingClassNodeFactory, so the
            // cast is expected to succeed.
            ClassifiedClassNode node = (ClassifiedClassNode) nodes.next();
            if (!uriKBClassifier.containsKBUri(node.getUris())) {
                continue;
            }
            boolean isExpected = node.getClassIds().contains(EXPECTED_CLASSES_CLASS_ID);
            boolean isReturned = node.getClassIds().contains(ANNOTATOR_CLASSES_CLASS_ID);
            if (isExpected && isReturned) {
                ++classCounts.truePositives;
            } else if (isExpected) {
                ++classCounts.falseNegatives;
            } else if (isReturned) {
                ++classCounts.falsePositives;
            }
        }
        return classCounts;
    }

    /**
     * Returns the number of elements the two sets have in common.
     * <p>
     * The smaller set is iterated and each element is looked up in the larger
     * one.
     *
     * @param set1
     *            the first set
     * @param set2
     *            the second set
     * @return the size of the intersection of both sets
     */
    public static int getIntersectionSize(Set<String> set1, Set<String> set2) {
        final boolean firstIsLarger = set1.size() > set2.size();
        final Set<String> probes = firstIsLarger ? set2 : set1;
        final Set<String> lookup = firstIsLarger ? set1 : set2;
        int shared = 0;
        for (String element : probes) {
            if (lookup.contains(element)) {
                shared++;
            }
        }
        return shared;
    }
}