package com.formulasearchengine.mathosphere.mlp.text;

import com.formulasearchengine.mathosphere.mlp.cli.FlinkMlpCommandConfig;
import com.formulasearchengine.mathosphere.mlp.pojos.*;
import com.formulasearchengine.mlp.evaluation.pojo.GoldEntry;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import org.apache.flink.api.common.functions.MapFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

import static com.formulasearchengine.mathosphere.utils.GoldUtil.getGoldEntryByTitle;
import static com.formulasearchengine.mathosphere.utils.GoldUtil.matchesGold;

/**
 * Extracts simple features such as pattern matches and word counts.
 * Use this class to generate the features of a document.
 */
public class SimpleFeatureExtractorMapper implements MapFunction<ParsedWikiDocument, WikiDocumentOutput> {

  private static final Logger LOGGER = LoggerFactory.getLogger(SimpleFeatureExtractorMapper.class);

  private final FlinkMlpCommandConfig config;
  private final List<GoldEntry> goldEntries;
  private final boolean extractionForTraining;

  public SimpleFeatureExtractorMapper(FlinkMlpCommandConfig config, List<GoldEntry> goldEntries) {
    this.config = config;
    this.goldEntries = goldEntries;
    extractionForTraining = goldEntries != null && !goldEntries.isEmpty();
  }

  @Override
  public WikiDocumentOutput map(ParsedWikiDocument doc) throws Exception {
    List<Relation> allIdentifierDefininesCandidates = new ArrayList<>();
    List<Sentence> sentences = doc.getSentences();
    Map<String, Integer> identifierSentenceDistanceMap =
        findSentencesWithIdentifierFirstOccurrences(sentences, doc.getIdentifiers());
    Multiset<String> frequencies = aggregateWords(sentences);
    int maxFrequency = getMaxFrequency(frequencies, doc.getTitle());
    if (extractionForTraining) {
      return getIdentifiersWithGoldInfo(doc, allIdentifierDefininesCandidates, sentences,
          identifierSentenceDistanceMap, frequencies, maxFrequency);
    } else {
      return getAllIdentifiers(doc, allIdentifierDefininesCandidates, sentences,
          identifierSentenceDistanceMap, frequencies, maxFrequency);
    }
  }

  private WikiDocumentOutput getIdentifiersWithGoldInfo(ParsedWikiDocument doc,
      List<Relation> allIdentifierDefininesCandidates, List<Sentence> sentences,
      Map<String, Integer> identifierSentenceDistanceMap, Multiset<String> frequencies,
      double maxFrequency) {
    GoldEntry goldEntry = getGoldEntryByTitle(goldEntries, doc.getTitle());
    final Integer fid = Integer.parseInt(goldEntry.getFid());
    MathTag seed = doc.getFormulas().stream()
        .filter(e -> e.getMarkUpType().equals(WikiTextUtils.MathMarkUpType.LATEX))
        .collect(Collectors.toList())
        .get(fid);
    for (int i = 0; i < sentences.size(); i++) {
      Sentence sentence = sentences.get(i);
      if (!sentence.getIdentifiers().isEmpty()) {
        LOGGER.debug("sentence {}", sentence);
      }
      Set<String> identifiers = sentence.getIdentifiers();
      // keep only identifiers that were extracted by the MLP pipeline
      identifiers.retainAll(seed.getIdentifiers(config).elementSet());
      SimplePatternMatcher matcher = SimplePatternMatcher.generatePatterns(identifiers);
      Collection<Relation> foundMatches = matcher.match(sentence, doc);
      for (Relation match : foundMatches) {
        List<String> identifiersInGold = goldEntry.getDefinitions().stream()
            .map(id -> id.getIdentifier())
            .collect(Collectors.toList());
        // Take only the identifiers that were extracted correctly to avoid
        // false negatives in the training set.
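        // A candidate relation whose identifier is missing from the gold data is
        // skipped entirely; the ones that remain are labeled below with relevance 2
        // if their (identifier, definition) pair matches the gold entry, 0 otherwise.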
        if (identifiersInGold.contains(match.getIdentifier())) {
          LOGGER.debug("found match {}", match);
          int freq = frequencies.count(
              match.getSentence().getWords().get(match.getWordPosition()).toLowerCase());
          match.setRelativeTermFrequency((double) freq / maxFrequency);
          if (i - identifierSentenceDistanceMap.get(match.getIdentifier()) < 0) {
            throw new RuntimeException("Cannot find identifier before first occurrence");
          }
          match.setDistanceFromFirstIdentifierOccurence(
              (double) (i - identifierSentenceDistanceMap.get(match.getIdentifier()))
                  / (double) doc.getSentences().size());
          match.setRelevance(matchesGold(match.getIdentifier(), match.getDefinition(), goldEntry) ? 2 : 0);
          allIdentifierDefininesCandidates.add(match);
        }
      }
    }
    LOGGER.info("extracted {} relations from {}", allIdentifierDefininesCandidates.size(), doc.getTitle());
    WikiDocumentOutput result = new WikiDocumentOutput(doc.getTitle(), goldEntry.getqID(),
        allIdentifierDefininesCandidates, null);
    result.setMaxSentenceLength(doc.getSentences().stream()
        .map(s -> s.getWords().size())
        .max(Comparator.naturalOrder())
        .get());
    return result;
  }

  private WikiDocumentOutput getAllIdentifiers(ParsedWikiDocument doc,
      List<Relation> allIdentifierDefininesCandidates, List<Sentence> sentences,
      Map<String, Integer> identifierSentenceDistanceMap, Multiset<String> frequencies,
      double maxFrequency) {
    for (int i = 0; i < sentences.size(); i++) {
      Sentence sentence = sentences.get(i);
      if (!sentence.getIdentifiers().isEmpty()) {
        LOGGER.debug("sentence {}", sentence);
      }
      // identifiers as extracted by the MLP pipeline
      Set<String> identifiers = sentence.getIdentifiers();
      SimplePatternMatcher matcher = SimplePatternMatcher.generatePatterns(identifiers);
      Collection<Relation> foundMatches = matcher.match(sentence, doc);
      for (Relation match : foundMatches) {
        LOGGER.debug("found match {}", match);
        int freq = frequencies.count(
            match.getSentence().getWords().get(match.getWordPosition()).toLowerCase());
        match.setRelativeTermFrequency((double) freq / maxFrequency);
        if (i - identifierSentenceDistanceMap.get(match.getIdentifier()) < 0) {
          throw new RuntimeException("Cannot find identifier before first occurrence");
        }
        match.setDistanceFromFirstIdentifierOccurence(
            (double) (i - identifierSentenceDistanceMap.get(match.getIdentifier()))
                / (double) doc.getSentences().size());
        allIdentifierDefininesCandidates.add(match);
      }
    }
    WikiDocumentOutput result = new WikiDocumentOutput(doc.getTitle(), "-1",
        allIdentifierDefininesCandidates, doc.getIdentifiers());
    Optional<Integer> lengthOfLongestSentence = doc.getSentences().stream()
        .map(s -> s.getWords().size())
        .max(Comparator.naturalOrder());
    // 1 as safe value, since 0 would lead to NaN in the division.
    result.setMaxSentenceLength(lengthOfLongestSentence.isPresent() ? lengthOfLongestSentence.get() : 1);
    return result;
  }

  private Map<String, Integer> findSentencesWithIdentifierFirstOccurrences(List<Sentence> sentences,
      Collection<String> identifiers) {
    Map<String, Integer> result = new HashMap<>();
    for (String identifier : identifiers) {
      for (int i = 0; i < sentences.size(); i++) {
        Sentence sentence = sentences.get(i);
        if (sentence.contains(identifier)) {
          result.put(identifier, i);
          break;
        }
      }
    }
    return result;
  }

  /**
   * Aggregates the words to make counting easy.
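   * Only words with at least three characters are counted (see the length check below).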
   *
   * @param sentences from which to aggregate the words.
   * @return Multiset with an entry for every word.
   */
  private Multiset<String> aggregateWords(List<Sentence> sentences) {
    Multiset<String> counts = HashMultiset.create();
    for (Sentence sentence : sentences) {
      for (Word word : sentence.getWords()) {
        if (word.getWord().length() >= 3) {
          counts.add(word.getWord().toLowerCase());
        }
      }
    }
    return counts;
  }

  private int getMaxFrequency(Multiset<String> frequencies, String title) {
    try {
      Multiset.Entry<String> max = Collections.max(frequencies.entrySet(),
          (e1, e2) -> Integer.compare(e1.getCount(), e2.getCount()));
      return max.getCount();
    } catch (NoSuchElementException e) {
      // no max present
      LOGGER.error("Error in " + title + ". Message: " + e.getMessage(), e);
      // 1 as safe value if anything goes wrong
      return 1;
    }
  }
}
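// A minimal usage sketch, not part of the class: the wiring below is hypothetical
// and assumes `config` (a FlinkMlpCommandConfig), `goldEntries`, and the upstream
// DataSet of parsed documents are provided by the surrounding Flink job.
//
//   ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
//   DataSet<ParsedWikiDocument> docs = ...; // output of the parsing stage
//   DataSet<WikiDocumentOutput> features =
//       docs.map(new SimpleFeatureExtractorMapper(config, goldEntries));
//
// Passing a non-empty goldEntries list switches the mapper into training mode
// (relations are labeled against the gold standard); passing null or an empty
// list extracts features for all identifiers.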