package edu.stanford.nlp.ie; import edu.stanford.nlp.ie.machinereading.structure.Span; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.simple.Sentence; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.util.ConfusionMatrix; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.util.StringUtils; import java.io.*; import java.text.DecimalFormat; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; import static edu.stanford.nlp.ie.KBPRelationExtractor.NERTag.*; import static edu.stanford.nlp.util.logging.Redwood.Util.endTrack; import static edu.stanford.nlp.util.logging.Redwood.Util.forceTrack; import static edu.stanford.nlp.util.logging.Redwood.log; /** * An interface for a KBP-style relation extractor * * @author Gabor Angeli */ public interface KBPRelationExtractor { /** * Classify the given sentence into the relation it expresses, with the associated * confidence. */ Pair<String,Double> classify(KBPInput input); /** * The special tag for no relation. */ String NO_RELATION = "no_relation"; /** * A list of valid KBP NER tags. */ enum NERTag { // ENUM_NAME NAME SHORT_NAME IS_REGEXNER_TYPE CAUSE_OF_DEATH("CAUSE_OF_DEATH", "COD", true), // note: these names must be upper case CITY("CITY", "CIT", true), // furthermore, DO NOT change the short names, or else serialization may break COUNTRY("COUNTRY", "CRY", true), CRIMINAL_CHARGE("CRIMINAL_CHARGE", "CC", true), DATE("DATE", "DT", false), IDEOLOGY("IDEOLOGY", "IDY", true), LOCATION("LOCATION", "LOC", false), MISC("MISC", "MSC", false), MODIFIER("MODIFIER", "MOD", false), NATIONALITY("NATIONALITY", "NAT", true), NUMBER("NUMBER", "NUM", false), ORGANIZATION("ORGANIZATION", "ORG", false), PERSON("PERSON", "PER", false), RELIGION("RELIGION", "REL", true), STATE_OR_PROVINCE("STATE_OR_PROVINCE", "ST", true), TITLE("TITLE", "TIT", true), URL("URL", "URL", true), DURATION("DURATION", "DUR", false), GPE("GPE", "GPE", false), // note(chaganty): This NER tag is solely used in the cold-start system for entities. // SCHOOL ("SCHOOL", "SCH", true), ; /** * The full name of this NER tag, as would come out of our NER or RegexNER system */ public final String name; /** * A short name for this NER tag, intended for compact serialization */ public final String shortName; /** * If true, this NER tag is not in the standard NER set, but is annotated via RegexNER */ public final boolean isRegexNERType; NERTag(String name, String shortName, boolean isRegexNERType) { this.name = name; this.shortName = shortName; this.isRegexNERType = isRegexNERType; } /** Find the slot for a given name */ public static Optional<NERTag> fromString(String name) { // Early termination if (StringUtils.isNullOrEmpty(name)) { return Optional.empty(); } // Cycle known NER tags name = name.toUpperCase(); for (NERTag slot : NERTag.values()) { if (slot.name.equals(name)) return Optional.of(slot); } for (NERTag slot : NERTag.values()) { if (slot.shortName.equals(name)) return Optional.of(slot); } // Some quick fixes return Optional.empty(); } } /** * Known relation types (last updated for the 2013 shared task). * * Note that changing the constants here can have far-reaching consequences in loading serialized * models, and various bits of code that have been hard-coded to these relation types (e.g., the various * consistency filters). * * <p> * <i>Note:</i> Neither per:spouse, org:founded_by, or X:organizations_founded are SINGLE relations * in the spec - these are made single here because our system otherwise over-predicts them. * </p> * * @author Gabor Angeli */ enum RelationType { PER_ALTERNATE_NAMES("per:alternate_names", true, 10, PERSON, Cardinality.LIST, new NERTag[]{PERSON, MISC}, new String[]{"NNP"}, 0.0353027270308107100), PER_CHILDREN("per:children", true, 5, PERSON, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP"}, 0.0058428110284504410), PER_CITIES_OF_RESIDENCE("per:cities_of_residence", true, 5, PERSON, Cardinality.LIST, new NERTag[]{CITY,}, new String[]{"NNP"}, 0.0136105679675116560), PER_CITY_OF_BIRTH("per:city_of_birth", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{CITY,}, new String[]{"NNP"}, 0.0358146961159769100), PER_CITY_OF_DEATH("per:city_of_death", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{CITY,}, new String[]{"NNP"}, 0.0102003332137774650), PER_COUNTRIES_OF_RESIDENCE("per:countries_of_residence", true, 5, PERSON, Cardinality.LIST, new NERTag[]{COUNTRY,}, new String[]{"NNP"}, 0.0107788293552082020), PER_COUNTRY_OF_BIRTH("per:country_of_birth", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{COUNTRY,}, new String[]{"NNP"}, 0.0223444134627622040), PER_COUNTRY_OF_DEATH("per:country_of_death", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{COUNTRY,}, new String[]{"NNP"}, 0.0060626395621941200), PER_EMPLOYEE_OF("per:employee_of", true, 10, PERSON, Cardinality.LIST, new NERTag[]{ORGANIZATION, COUNTRY, STATE_OR_PROVINCE, CITY}, new String[]{"NNP"}, 2.0335281901169719200), PER_LOC_OF_BIRTH("per:LOCATION_of_birth", true, 3, PERSON, Cardinality.LIST, new NERTag[]{CITY, STATE_OR_PROVINCE, COUNTRY}, new String[]{"NNP"}, 0.0165825918941120660), PER_LOC_OF_DEATH("per:LOCATION_of_death", true, 3, PERSON, Cardinality.LIST, new NERTag[]{CITY, STATE_OR_PROVINCE, COUNTRY}, new String[]{"NNP"}, 0.0165825918941120660), PER_LOC_OF_RESIDENCE("per:LOCATION_of_residence", true, 3, PERSON, Cardinality.LIST, new NERTag[]{STATE_OR_PROVINCE,}, new String[]{"NNP"}, 0.0165825918941120660), PER_MEMBER_OF("per:member_of", true, 10, PERSON, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP"}, 0.0521716745149309900), PER_ORIGIN("per:origin", true, 10, PERSON, Cardinality.LIST, new NERTag[]{NATIONALITY, COUNTRY}, new String[]{"NNP"}, 0.0069795559463618380), PER_OTHER_FAMILY("per:other_family", true, 5, PERSON, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP"}, 2.7478566717959990E-5), PER_PARENTS("per:parents", true, 5, PERSON, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP"}, 0.0032222235077692030), PER_SCHOOLS_ATTENDED("per:schools_attended", true, 5, PERSON, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP"}, 0.0054696810172276150), PER_SIBLINGS("per:siblings", true, 5, PERSON, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP"}, 1.000000000000000e-99), PER_SPOUSE("per:spouse", true, 3, PERSON, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP"}, 0.0164075968113292680), PER_STATE_OR_PROVINCES_OF_BIRTH("per:stateorprovince_of_birth", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{STATE_OR_PROVINCE,}, new String[]{"NNP"}, 0.0165825918941120660), PER_STATE_OR_PROVINCES_OF_DEATH("per:stateorprovince_of_death", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{STATE_OR_PROVINCE,}, new String[]{"NNP"}, 0.0050083303444366030), PER_STATE_OR_PROVINCES_OF_RESIDENCE("per:stateorprovinces_of_residence", true, 5, PERSON, Cardinality.LIST, new NERTag[]{STATE_OR_PROVINCE,}, new String[]{"NNP"}, 0.0066787379528178550), PER_AGE("per:age", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{NUMBER, DURATION}, new String[]{"CD", "NN"}, 0.0483159977322951300), PER_DATE_OF_BIRTH("per:date_of_birth", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{DATE}, new String[]{"CD", "NN"}, 0.0743584477791533200), PER_DATE_OF_DEATH("per:date_of_death", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{DATE}, new String[]{"CD", "NN"}, 0.0189819046406960460), PER_CAUSE_OF_DEATH("per:cause_of_death", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{CAUSE_OF_DEATH}, new String[]{"NN"}, 1.0123682475037891E-5), PER_CHARGES("per:charges", true, 5, PERSON, Cardinality.LIST, new NERTag[]{CRIMINAL_CHARGE}, new String[]{"NN"}, 3.8614617440501670E-4), PER_RELIGION("per:religion", true, 3, PERSON, Cardinality.SINGLE, new NERTag[]{RELIGION}, new String[]{"NN"}, 7.6650738739572610E-4), PER_TITLE("per:title", true, 15, PERSON, Cardinality.LIST, new NERTag[]{TITLE, MODIFIER}, new String[]{"NN"}, 0.0334283995325751200), ORG_ALTERNATE_NAMES("org:alternate_names", true, 10, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION, MISC}, new String[]{"NNP"}, 0.0552058867767352000), ORG_CITY_OF_HEADQUARTERS("org:city_of_headquarters", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{CITY, LOCATION}, new String[]{"NNP"}, 0.0555949254318473740), ORG_COUNTRY_OF_HEADQUARTERS("org:country_of_headquarters", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{COUNTRY, NATIONALITY}, new String[]{"NNP"}, 0.0580217167451493100), ORG_FOUNDED_BY("org:founded_by", true, 3, ORGANIZATION, Cardinality.LIST, new NERTag[]{PERSON, ORGANIZATION}, new String[]{"NNP"}, 0.0050806423621154450), ORG_LOC_OF_HEADQUARTERS("org:LOCATION_of_headquarters", true, 10, ORGANIZATION, Cardinality.LIST, new NERTag[]{CITY, STATE_OR_PROVINCE, COUNTRY,}, new String[]{"NNP"}, 0.0555949254318473740), ORG_MEMBER_OF("org:member_of", true, 20, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION, STATE_OR_PROVINCE, COUNTRY,}, new String[]{"NNP"}, 0.0396298781687126140), ORG_MEMBERS("org:members", true, 20, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION, COUNTRY}, new String[]{"NNP"}, 0.0012220730987724312), ORG_PARENTS("org:parents", true, 10, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION,}, new String[]{"NNP"}, 0.0550048593675880200), ORG_POLITICAL_RELIGIOUS_AFFILIATION("org:political/religious_affiliation", true, 5, ORGANIZATION, Cardinality.LIST, new NERTag[]{IDEOLOGY, RELIGION}, new String[]{"NN", "JJ"}, 0.0059266929689578970), ORG_SHAREHOLDERS("org:shareholders", true, 10, ORGANIZATION, Cardinality.LIST, new NERTag[]{PERSON, ORGANIZATION}, new String[]{"NNP"}, 1.1569922828614734E-5), ORG_STATE_OR_PROVINCES_OF_HEADQUARTERS("org:stateorprovince_of_headquarters", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{STATE_OR_PROVINCE}, new String[]{"NNP"}, 0.0312619314829170100), ORG_SUBSIDIARIES("org:subsidiaries", true, 20, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP"}, 0.0162412791706679320), ORG_TOP_MEMBERS_SLASH_EMPLOYEES("org:top_members/employees", true, 10, ORGANIZATION, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP"}, 0.0907168724184609800), ORG_DISSOLVED("org:dissolved", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{DATE}, new String[]{"CD", "NN"}, 0.0023877428237553656), ORG_FOUNDED("org:founded", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{DATE}, new String[]{"CD", "NN"}, 0.0796314401082944800), ORG_NUMBER_OF_EMPLOYEES_SLASH_MEMBERS("org:number_of_employees/members", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{NUMBER}, new String[]{"CD", "NN"}, 0.0366274831946870950), ORG_WEBSITE("org:website", true, 3, ORGANIZATION, Cardinality.SINGLE, new NERTag[]{URL}, new String[]{"NNP", "NN"}, 0.0051544006201478640), // Inverse types ORG_EMPLOYEES("org:employees_or_members", false, 68, ORGANIZATION, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_EMPLOYEES("gpe:employees_or_members", false, 10, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), ORG_STUDENTS("org:students", false, 50, ORGANIZATION, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_BIRTHS_IN_CITY("gpe:births_in_city", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_BIRTHS_IN_STATE_OR_PROVINCE("gpe:births_in_stateorprovince", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_BIRTHS_IN_COUNTRY("gpe:births_in_country", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_RESIDENTS_IN_CITY("gpe:residents_of_city", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_RESIDENTS_IN_STATE_OR_PROVINCE("gpe:residents_of_stateorprovince", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_RESIDENTS_IN_COUNTRY("gpe:residents_of_country", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_DEATHS_IN_CITY("gpe:deaths_in_city", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_DEATHS_IN_STATE_OR_PROVINCE("gpe:deaths_in_stateorprovince", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_DEATHS_IN_COUNTRY("gpe:deaths_in_country", false, 50, GPE, Cardinality.LIST, new NERTag[]{PERSON}, new String[]{"NNP", "NN"}, 0.0051544006201478640), PER_HOLDS_SHARES_IN("per:holds_shares_in", false, 10, PERSON, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_HOLDS_SHARES_IN("gpe:holds_shares_in", false, 10, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), ORG_HOLDS_SHARES_IN("org:holds_shares_in", false, 10, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), PER_ORGANIZATIONS_FOUNDED("per:organizations_founded", false, 3, PERSON, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_ORGANIZATIONS_FOUNDED("gpe:organizations_founded", false, 3, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), ORG_ORGANIZATIONS_FOUNDED("org:organizations_founded", false, 3, ORGANIZATION, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), PER_TOP_EMPLOYEE_OF("per:top_member_employee_of", false, 5, PERSON, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_MEMBER_OF("gpe:member_of", false, 10, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP"}, 0.0396298781687126140), GPE_SUBSIDIARIES("gpe:subsidiaries", false, 10, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP"}, 0.0396298781687126140), GPE_HEADQUARTERS_IN_CITY("gpe:headquarters_in_city", false, 50, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_HEADQUARTERS_IN_STATE_OR_PROVINCE("gpe:headquarters_in_stateorprovince", false, 50, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), GPE_HEADQUARTERS_IN_COUNTRY("gpe:headquarters_in_country", false, 50, GPE, Cardinality.LIST, new NERTag[]{ORGANIZATION}, new String[]{"NNP", "NN"}, 0.0051544006201478640), ; public enum Cardinality { SINGLE, LIST } /** * A canonical name for this relation type. This is the official 2010 relation name, * that has since changed. */ public final String canonicalName; /** * If true, realtation was one of the original (non-inverse) KBP relation. */ public final boolean isOriginalRelation; /** * A guess of the maximum number of results to query for this relation. * Only really relevant for cold start. */ public final int queryLimit; /** * The entity type (left arg type) associated with this relation. That is, either a PERSON or an ORGANIZATION "slot". */ public final NERTag entityType; /** * The cardinality of this entity. That is, can multiple right arguments participate in this relation (born_in vs. lived_in) */ public final Cardinality cardinality; /** * Valid named entity labels for the right argument to this relation */ public final Set<NERTag> validNamedEntityLabels; /** * Valid POS [prefixes] for the right argument to this relation (e.g., can only take nouns, or can only take numbers, etc.) */ public final Set<String> validPOSPrefixes; /** * The prior for how often this relation occurs in the training data. * Note that this prior is not necessarily accurate for the test data. */ public final double priorProbability; RelationType(String canonicalName, boolean isOriginalRelation, int queryLimit, NERTag type, Cardinality cardinality, NERTag[] validNamedEntityLabels, String[] validPOSPrefixes, double priorProbability) { this.canonicalName = canonicalName; this.isOriginalRelation = isOriginalRelation; this.queryLimit = queryLimit; this.entityType = type; this.cardinality = cardinality; this.validNamedEntityLabels = new HashSet<>(Arrays.asList(validNamedEntityLabels)); this.validPOSPrefixes = new HashSet<>(Arrays.asList(validPOSPrefixes)); this.priorProbability = priorProbability; } /** A small cache of names to relation types; we call fromString() a lot in the code, usually expecting it to be very fast */ private static final Map<String, RelationType> cachedFromString = new HashMap<>(); /** Find the slot for a given name */ public static Optional<RelationType> fromString(String name) { if (name == null) { return Optional.empty(); } String originalName = name; if (cachedFromString.get(name) != null) { return Optional.of(cachedFromString.get(name)); } if (cachedFromString.containsKey(name)) { return Optional.empty(); } // Try naive for (RelationType slot : RelationType.values()) { if (slot.canonicalName.equals(name) || slot.name().equals(name)) { cachedFromString.put(originalName, slot); return Optional.of(slot); } } // Replace slashes name = name.toLowerCase().replaceAll("[Ss][Ll][Aa][Ss][Hh]", "/"); for (RelationType slot : RelationType.values()) { if (slot.canonicalName.equalsIgnoreCase(name)) { cachedFromString.put(originalName, slot); return Optional.of(slot); } } cachedFromString.put(originalName, null); return Optional.empty(); } /** * Returns whether two entity types could plausibly have a relation hold between them. * That is, is there a known relation type that would hold between these two entity types. * @param entityType The NER tag of the entity. * @param slotValueType The NER tag of the slot value. * @return True if there is a plausible relation which could occur between these two types. */ public static boolean plausiblyHasRelation(NERTag entityType, NERTag slotValueType) { for (RelationType rel : RelationType.values()) { if (rel.entityType == entityType && rel.validNamedEntityLabels.contains(slotValueType)) { return true; } } return false; } } @SuppressWarnings("unused") class KBPInput { public final Span subjectSpan; public final Span objectSpan; public final NERTag subjectType; public final NERTag objectType; public final Sentence sentence; public KBPInput(Span subjectSpan, Span objectSpan, NERTag subjectType, NERTag objectType, Sentence sentence) { this.subjectSpan = subjectSpan; this.objectSpan = objectSpan; this.subjectType = subjectType; this.objectType = objectType; this.sentence = sentence; } public Sentence getSentence() { return sentence; } public Span getSubjectSpan() { return subjectSpan; } public String getSubjectText() { return StringUtils.join(sentence.originalTexts().subList(subjectSpan.start(), subjectSpan.end()).stream(), " "); } public Span getObjectSpan() { return objectSpan; } public String getObjectText() { return StringUtils.join(sentence.originalTexts().subList(objectSpan.start(), objectSpan.end()).stream(), " "); } @Override public String toString() { return "KBPInput{" + ", subjectSpan=" + subjectSpan + ", objectSpan=" + objectSpan + ", sentence=" + sentence + '}'; } } /** * Read a dataset from a CoNLL formatted input file * @param conllInputFile The input file, formatted as a TSV * @return A list of examples. */ @SuppressWarnings("StatementWithEmptyBody") static List<Pair<KBPInput, String>> readDataset(File conllInputFile) throws IOException { BufferedReader reader = IOUtils.readerFromFile(conllInputFile); List<Pair<KBPInput, String>> examples = new ArrayList<>(); int i = 0; String relation = null; List<String> tokens = new ArrayList<>(); Span subject = new Span(Integer.MAX_VALUE, Integer.MIN_VALUE); NERTag subjectNER = null; Span object = new Span(Integer.MAX_VALUE, Integer.MIN_VALUE); NERTag objectNER = null; String line = reader.readLine(); if (!line.startsWith("#")) { throw new IllegalArgumentException("First line of input file should be header definition"); } while ( (line = reader.readLine()) != null ) { String[] fields = line.split("\t"); if (relation == null) { // Case: read the relation assert fields.length == 1; relation = fields[0]; } else if (fields.length == 9) { // Case: read a token tokens.add(fields[0]); if ("SUBJECT".equals(fields[1])) { subject = new Span(Math.min(subject.start(), i), Math.max(subject.end(), i + 1)); subjectNER = valueOf(fields[2].toUpperCase()); } else if ("OBJECT".equals(fields[3])) { object = new Span(Math.min(object.start(), i), Math.max(object.end(), i + 1)); objectNER = valueOf(fields[4].toUpperCase()); } else if ("-".equals(fields[1]) && "-".equals(fields[3])) { // do nothing } else { throw new IllegalStateException("Could not parse CoNLL file"); } i += 1; } else if (StringUtils.isNullOrEmpty(line.trim())) { // Case: commit a sentence examples.add(Pair.makePair(new KBPInput( subject, object, subjectNER, objectNER, new Sentence(tokens) ), relation)); // (clear the variables) i = 0; relation = null; tokens = new ArrayList<>(); subject = new Span(Integer.MAX_VALUE, Integer.MIN_VALUE); object = new Span(Integer.MAX_VALUE, Integer.MIN_VALUE); } else { throw new IllegalStateException("Could not parse CoNLL file"); } } return examples; } /** A class to compute the accuracy of a relation extractor. */ @SuppressWarnings("unused") class Accuracy { private static class PerRelationStat implements Comparable<PerRelationStat> { public final String name; public final double precision; public final double recall; public final int predictedCount; public final int goldCount; public PerRelationStat(String name, double precision, double recall, int predictedCount, int goldCount) { this.name = name; this.precision = precision; this.recall = recall; this.predictedCount = predictedCount; this.goldCount = goldCount; } public double f1() { if (precision == 0.0 && recall == 0.0) { return 0.0; } else { return 2.0 * precision * recall / (precision + recall); } } @SuppressWarnings("NullableProblems") @Override public int compareTo(PerRelationStat o) { if (this.precision < o.precision) { return -1; } else if (this.precision > o.precision) { return 1; } else { return 0; } } @Override public String toString() { DecimalFormat df = new DecimalFormat("0.00%"); return "[" + name + "] pred/gold: " + predictedCount + "/" + goldCount + " P: " + df.format(precision) + " R: " + df.format(recall) + " F1: " + df.format(f1()); } } private Counter<String> correctCount = new ClassicCounter<>(); private Counter<String> predictedCount = new ClassicCounter<>(); private Counter<String> goldCount = new ClassicCounter<>(); private Counter<String> totalCount = new ClassicCounter<>(); public final ConfusionMatrix<String> confusion = new ConfusionMatrix<>(); public void predict(Set<String> predictedRelationsRaw, Set<String> goldRelationsRaw) { Set<String> predictedRelations = new HashSet<>(predictedRelationsRaw); predictedRelations.remove(NO_RELATION); Set<String> goldRelations = new HashSet<>(goldRelationsRaw); goldRelations.remove(NO_RELATION); // Register the prediction for (String pred : predictedRelations) { if (goldRelations.contains(pred)) { correctCount.incrementCount(pred); } predictedCount.incrementCount(pred); } goldRelations.forEach(goldCount::incrementCount); HashSet<String> allRelations = new HashSet<String>(){{ addAll(predictedRelations); addAll(goldRelations); }}; allRelations.forEach(totalCount::incrementCount); // Register the confusion matrix if (predictedRelations.size() == 1 && goldRelations.size() == 1) { confusion.add(predictedRelations.iterator().next(), goldRelations.iterator().next()); } if (predictedRelations.size() == 1 && goldRelations.isEmpty()) { confusion.add(predictedRelations.iterator().next(), "NR"); } if (predictedRelations.isEmpty() && goldRelations.size() == 1) { confusion.add("NR", goldRelations.iterator().next()); } } public double precision(String relation) { if (predictedCount.getCount(relation) == 0) { return 1.0; } return correctCount.getCount(relation) / predictedCount.getCount(relation); } public double precisionMicro() { if (predictedCount.totalCount() == 0) { return 1.0; } return correctCount.totalCount() / predictedCount.totalCount(); } public double precisionMacro() { double sumPrecision = 0.0; for (String rel : totalCount.keySet()) { sumPrecision += precision(rel); } return sumPrecision / ((double) totalCount.size()); } public double recall(String relation) { if (goldCount.getCount(relation) == 0) { return 0.0; } return correctCount.getCount(relation) / goldCount.getCount(relation); } public double recallMicro() { if (goldCount.totalCount() == 0) { return 0.0; } return correctCount.totalCount() / goldCount.totalCount(); } public double recallMacro() { double sumRecall = 0.0; for (String rel : totalCount.keySet()) { sumRecall += recall(rel); } return sumRecall / ((double) totalCount.size()); } public double f1(String relation) { return 2.0 * precision(relation) * recall(relation) / (precision(relation) + recall(relation)); } public double f1Micro() { return 2.0 * precisionMicro() * recallMicro() / (precisionMicro() + recallMicro()); } public double f1Macro() { return 2.0 * precisionMacro() * recallMacro() / (precisionMacro() + recallMacro()); } public void dumpPerRelationStats(PrintStream out) { List<PerRelationStat> stats = goldCount.keySet().stream().map(relation -> new PerRelationStat(relation, precision(relation), recall(relation), (int) predictedCount.getCount(relation), (int) goldCount.getCount(relation))).collect(Collectors.toList()); Collections.sort(stats); out.println("Per-relation Accuracy"); for (PerRelationStat stat : stats) { out.println(stat); } } public void dumpPerRelationStats() { dumpPerRelationStats(System.out); } public void toString(PrintStream out) { out.println(); out.println("PRECISION (micro average): " + new DecimalFormat("0.000%").format(precisionMicro())); out.println("RECALL (micro average): " + new DecimalFormat("0.000%").format(recallMicro())); out.println("F1 (micro average): " + new DecimalFormat("0.000%").format(f1Micro())); out.println(); out.println("PRECISION (macro average): " + new DecimalFormat("0.000%").format(precisionMacro())); out.println("RECALL (macro average): " + new DecimalFormat("0.000%").format(recallMacro())); out.println("F1 (macro average): " + new DecimalFormat("0.000%").format(f1Macro())); out.println(); } public String toString() { ByteArrayOutputStream bs = new ByteArrayOutputStream(); PrintStream out = new PrintStream(bs); toString(out); return bs.toString(); } /** * A short, single line summary of the micro-precision/recall/f1. */ public String toOneLineString() { return "P: " + new DecimalFormat("0.000%").format(precisionMicro()) + " " + "R: " + new DecimalFormat("0.000%").format(recallMicro()) + " " + "F1: " + new DecimalFormat("0.000%").format(f1Micro()); } } default Accuracy computeAccuracy(Stream<Pair<KBPInput, String>> examples, Optional<PrintStream> predictOut) { forceTrack("Accuracy"); Accuracy accuracy = new Accuracy(); AtomicInteger testI = new AtomicInteger(0); DecimalFormat confidenceFormat = new DecimalFormat("0.0000"); forceTrack("Featurizing"); examples.parallel().map(example -> { Pair<String, Double> predicted = this.classify(example.first); synchronized (accuracy) { accuracy.predict(Collections.singleton(predicted.first), Collections.singleton(example.second)); } if (testI.incrementAndGet() % 1000 == 0) { log(KBPRelationExtractor.class, "[" + testI.get() + "] " + accuracy.toOneLineString()); } return predicted.first + "\t" + confidenceFormat.format(predicted.second); }) .forEachOrdered(line -> { if (predictOut.isPresent()) { predictOut.get().println(line); } }); endTrack("Featurizing"); log(accuracy.toString()); endTrack("Accuracy"); return accuracy; } }