package edu.stanford.nlp.ie.machinereading;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.util.CoreMap;

/**
 * Simple extractor which combines the outputs of several other Extractors.
 * Currently only works with RelationMentions. Also note that this
 * implementation deduplicates with Sets and therefore will not preserve the
 * original order of the RelationMentions.
 *
 * @author David McClosky
 */
public class ExtractorMerger implements Extractor {

  private static final long serialVersionUID = 1L;
  private static final Logger logger = Logger.getLogger(ExtractorMerger.class.getName());

  private final Extractor[] extractors;

  public ExtractorMerger(Extractor[] extractors) {
    if (extractors.length < 2) {
      throw new IllegalArgumentException("We need at least 2 extractors for ExtractorMerger to make sense.");
    }
    this.extractors = extractors;
  }

  /**
   * Runs every component extractor over the dataset and merges their
   * RelationMentions sentence by sentence, deduplicating via
   * RelationMention equality.
   */
  @Override
  public void annotate(Annotation dataset) {
    // TODO for now, we only merge RelationMentions
    logger.info("Extractor 0 annotating dataset.");
    extractors[0].annotate(dataset);

    // store all the RelationMentions per sentence
    List<Set<RelationMention>> allRelationMentions = new ArrayList<>();
    for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
      List<RelationMention> relationMentions = sentence.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
      Set<RelationMention> uniqueRelationMentions = new HashSet<>(relationMentions);
      allRelationMentions.add(uniqueRelationMentions);
    }

    // skip the first extractor since we ran it at the top
    for (int extractorIndex = 1; extractorIndex < extractors.length; extractorIndex++) {
      logger.info("Extractor " + extractorIndex + " annotating dataset.");
      Extractor extractor = extractors[extractorIndex];
      extractor.annotate(dataset);

      // walk through all sentences and merge their RelationMentions into the combined per-sentence sets
      int sentenceIndex = 0;
      for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
        List<RelationMention> relationMentions = sentence.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
        allRelationMentions.get(sentenceIndex).addAll(relationMentions);
        sentenceIndex++;
      }
    }

    // put all merged relations back into the dataset
    int sentenceIndex = 0;
    for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
      Set<RelationMention> uniqueRelationMentions = allRelationMentions.get(sentenceIndex);
      List<RelationMention> relationMentions = new ArrayList<>(uniqueRelationMentions);
      sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relationMentions);
      sentenceIndex++;
    }
  }
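  /*
   * Usage sketch for the factory method below (illustrative; the model paths
   * are hypothetical placeholders): build a merger from serialized
   * BasicRelationExtractor models, then annotate a dataset.
   *
   *   Extractor merged = ExtractorMerger.buildRelationExtractorMerger(
   *       new String[] { "relation-model-a.ser", "relation-model-b.ser" });
   *   merged.annotate(dataset);  // dataset must already carry SentencesAnnotation
   */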
  public static Extractor buildRelationExtractorMerger(String[] extractorModelNames) {
    BasicRelationExtractor[] relationExtractorComponents = new BasicRelationExtractor[extractorModelNames.length];
    for (int i = 0; i < extractorModelNames.length; i++) {
      String modelName = extractorModelNames[i];
      logger.info("Loading model " + i + " for model merging from " + modelName);
      try {
        relationExtractorComponents[i] = BasicRelationExtractor.load(modelName);
      } catch (IOException | ClassNotFoundException e) {
        logger.log(Level.SEVERE, "Error loading model " + modelName, e);
      }
    }
    return new ExtractorMerger(relationExtractorComponents);
  }

  @Override
  public void setLoggerLevel(Level level) {
    logger.setLevel(level);
  }

  // stubs required by the Extractor interface; they do nothing since this model is neither trainable nor savable

  @Override
  public void save(String path) throws IOException {
    // intentionally a no-op: ExtractorMerger carries no state of its own to serialize
  }

  @Override
  public void train(Annotation dataset) {
    // intentionally a no-op: merging happens at annotation time, so there is nothing to train
  }
}
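/*
 * Minimal driver sketch (illustrative; the class name and the
 * loadAnnotatedDataset() helper are hypothetical): builds a merger from the
 * serialized relation models named on the command line and annotates a
 * dataset. Producing an Annotation that already carries SentencesAnnotation
 * and entity mentions is assumed to happen upstream, e.g. in a MachineReading
 * pipeline, and is stubbed out here.
 */
class ExtractorMergerDemo {

  public static void main(String[] args) {
    // args: paths to at least two serialized BasicRelationExtractor models
    Extractor merged = ExtractorMerger.buildRelationExtractorMerger(args);
    Annotation dataset = loadAnnotatedDataset();
    merged.annotate(dataset);
  }

  private static Annotation loadAnnotatedDataset() {
    // Placeholder: a real driver would supply a sentence-split, entity-annotated dataset here.
    throw new UnsupportedOperationException("Supply a pre-annotated dataset.");
  }
}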