package edu.stanford.nlp.pipeline; import junit.framework.TestCase; import edu.stanford.nlp.ie.util.*; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.ling.CoreAnnotations; import edu.stanford.nlp.util.CoreMap; import java.io.*; import java.util.*; public class KBPAnnotatorBenchmark extends TestCase { public HashMap<String,String> docIDToText; public HashMap<String,Set<String>> docIDToRelations; public StanfordCoreNLP pipeline; public String KBP_DOCS_DIR; public String GOLD_RELATIONS_PATH; public double KBP_MINIMUM_SCORE; public void loadGoldData() { // initialize HashMaps docIDToText = new HashMap<String,String>(); docIDToRelations = new HashMap<String,Set<String>>(); // load the gold relations from gold relations file List<String> goldRelationLines = IOUtils.linesFromFile(GOLD_RELATIONS_PATH); for (String relationLine : goldRelationLines) { String[] docIDAndRelation = relationLine.split("\t"); if (docIDToRelations.get(docIDAndRelation[0]) == null) { docIDToRelations.put(docIDAndRelation[0], new HashSet<String>()); } docIDToRelations.get(docIDAndRelation[0]).add(docIDAndRelation[1]); } // load the text for each docID File directoryWithDocs = new File(KBP_DOCS_DIR); File[] allFiles = directoryWithDocs.listFiles(); for (File kbpTestDocFile : allFiles) { String kbpTestDocID = kbpTestDocFile.getName(); String kbpTestDocPath = kbpTestDocFile.getAbsolutePath(); String kbpTestDocContents = IOUtils.stringFromFile(kbpTestDocPath); docIDToText.put(kbpTestDocID, kbpTestDocContents); } } private String convertRelationName(String relationName) { /*if (relationName.equals("org:top_members/employees")) { return "org:top_members_employees"; }*/ if (relationName.equals("per:employee_of")) { return "per:employee_or_member_of"; } if (relationName.equals("per:stateorprovinces_of_residence")) { return "per:statesorprovinces_of_residence"; } if (relationName.equals("org:number_of_employees/members")) { return "org:number_of_employees_members"; } if (relationName.equals("org:stateorprovince_of_headquarters")) { return "org:stateprovince_of_headquarters"; } if (relationName.equals("per:other_family")) { return "per:otherfamily"; } if (relationName.equals("org:founded")) { return "org:date_founded"; } if (relationName.equals("org:political/religious_affiliation")) { return "org:political_religious_affiliation"; } return relationName; } public Set<String> convertKBPTriplesToStrings(List<RelationTriple> relationTripleList) { HashSet<String> foundRelationStrings = new HashSet<String>(); for (RelationTriple rt : relationTripleList) { String relationName = convertRelationName(rt.relationGloss()); String relationString = relationName+"("+rt.subjectGloss()+","+rt.objectGloss()+")"; foundRelationStrings.add(relationString); } return foundRelationStrings; } public void testKBPAnnotatorResults() { int totalGoldRelations = 0; int totalCorrectFoundRelations = 0; int totalWrongFoundRelations = 0; int totalGuessRelations = 0; double finalF1 = 0.0; for (String docID : docIDToText.keySet()) { System.out.println("---"); System.out.println(docID); Annotation currAnnotation = new Annotation(docIDToText.get(docID)); pipeline.annotate(currAnnotation); // increment number of seen gold relations int docGoldRelationSetSize = 0; if (docIDToRelations.get(docID) != null) { docGoldRelationSetSize = docIDToRelations.get(docID).size(); } totalGoldRelations += docGoldRelationSetSize; ArrayList<RelationTriple> relationTriplesForThisDoc = new ArrayList<RelationTriple>(); for (CoreMap sentence : currAnnotation.get(CoreAnnotations.SentencesAnnotation.class)) { List<RelationTriple> rtList = sentence.get(CoreAnnotations.KBPTriplesAnnotation.class); for (RelationTriple rt : rtList) { System.out.println("\t"+rt.toString()); relationTriplesForThisDoc.add(rt); } } Set<String> foundRelationStrings = convertKBPTriplesToStrings(relationTriplesForThisDoc); HashSet<String> intersectionOfFoundAndGold = new HashSet<String>(foundRelationStrings); if (docIDToRelations.get(docID) != null) { intersectionOfFoundAndGold.retainAll(docIDToRelations.get(docID)); totalCorrectFoundRelations += (intersectionOfFoundAndGold.size()); totalWrongFoundRelations += (foundRelationStrings.size()-intersectionOfFoundAndGold.size()); } else { totalWrongFoundRelations += foundRelationStrings.size(); } totalGuessRelations += foundRelationStrings.size(); System.out.println("curr score: "); double recall = (((double) totalCorrectFoundRelations)/((double) totalGoldRelations)); double precision = (((double) totalCorrectFoundRelations)/((double) totalGuessRelations)); System.out.println("\trecall: "+recall); System.out.println("\tprecision: "+precision); double f1 = (2 * (precision * recall))/(precision + recall); System.out.println("\tf1: "+f1); finalF1 = f1; } // check final F1 score is assertTrue("f1 score: " + finalF1 +" is below threshold of "+KBP_MINIMUM_SCORE , finalF1 >= KBP_MINIMUM_SCORE); } }