package com.formulasearchengine.mathosphere.mlp.text;

import com.formulasearchengine.mathosphere.mlp.cli.MachineLearningDefinienExtractionConfig;
import com.formulasearchengine.mathosphere.mlp.pojos.*;
import com.formulasearchengine.mlp.evaluation.pojo.GoldEntry;
import com.formulasearchengine.mlp.evaluation.pojo.IdentifierDefinition;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import org.apache.flink.api.common.functions.MapFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Extracts identifier-definiens candidate relations and their features from a parsed
 * wiki document, scoring each candidate against the gold standard.
 */
public class SimpleFeatureExtractor implements MapFunction<ParsedWikiDocument, WikiDocumentOutput> {

  private static final Logger LOGGER = LoggerFactory.getLogger(SimpleFeatureExtractor.class);

  private final MachineLearningDefinienExtractionConfig config;
  private final List<GoldEntry> goldEntries;

  public SimpleFeatureExtractor(MachineLearningDefinienExtractionConfig config, List<GoldEntry> goldEntries) {
    this.config = config;
    this.goldEntries = goldEntries;
  }

  @Override
  public WikiDocumentOutput map(ParsedWikiDocument doc) throws Exception {
    List<Relation> foundFeatures = Lists.newArrayList();
    List<Sentence> sentences = doc.getSentences();
    Map<String, Integer> identifierSentenceDistanceMap =
        findSentencesWithIdentifierFirstOccurrences(sentences, doc.getIdentifiers());
    Multiset<String> frequencies = aggregateWords(sentences);
    int maxFrequency = calculateMax(frequencies);

    // Look up the gold entry for this document; gold titles use underscores instead of spaces.
    GoldEntry goldEntry = goldEntries.stream()
        .filter(e -> e.getTitle().equals(doc.getTitle().replaceAll(" ", "_")))
        .findFirst()
        .orElseThrow(() -> new NoSuchElementException("No gold entry for document " + doc.getTitle()));
    final int fid = Integer.parseInt(goldEntry.getFid());
    // The seed formula is the fid-th LaTeX formula of the document.
    final MathTag seed = doc.getFormulas().stream()
        .filter(e -> e.getMarkUpType().equals(WikiTextUtils.MathMarkUpType.LATEX))
        .collect(Collectors.toList())
        .get(fid);

    for (int i = 0; i < sentences.size(); i++) {
      Sentence sentence = sentences.get(i);
      if (!sentence.getIdentifiers().isEmpty()) {
        LOGGER.debug("sentence {}", sentence);
      }
      // Restrict the sentence's identifiers to those that occur in the seed formula.
      Set<String> identifiers = sentence.getIdentifiers();
      identifiers.retainAll(seed.getIdentifiers(config).elementSet());
      SimplePatternMatcher matcher = SimplePatternMatcher.generatePatterns(identifiers);
      Collection<Relation> foundMatches = matcher.match(sentence, doc);
      for (Relation match : foundMatches) {
        LOGGER.debug("found match {}", match);
        // Relative term frequency of the definiens word, normalized by the most frequent word.
        int freq = frequencies.count(match.getSentence().getWords().get(match.getWordPosition()).toLowerCase());
        match.setRelativeTermFrequency((double) freq / (double) maxFrequency);
        if (i - identifierSentenceDistanceMap.get(match.getIdentifier()) < 0) {
          throw new RuntimeException("Cannot find identifier before first occurrence");
        }
        // Sentence distance from the identifier's first occurrence, normalized by document length.
        match.setDistanceFromFirstIdentifierOccurence(
            (double) (i - identifierSentenceDistanceMap.get(match.getIdentifier())) / (double) doc.getSentences().size());
        match.setRelevance(matchesGold(match, goldEntry) ? 2 : 0);
        foundFeatures.add(match);
      }
    }

    LOGGER.info("extracted {} relations from {}", foundFeatures.size(), doc.getTitle());
    WikiDocumentOutput result = new WikiDocumentOutput(doc.getTitle(), goldEntry.getqID(), foundFeatures, null);
    result.setMaxSentenceLength(doc.getSentences().stream()
        .map(s -> s.getWords().size())
        .max(Comparator.naturalOrder())
        .get());
    return result;
  }

  public boolean matchesGold(Relation relation, GoldEntry gold) {
    return matchesGold(relation.getIdentifier(), relation.getDefinition(), gold);
  }

  /**
   * Checks whether the (identifier, definiens) pair is contained in the gold standard.
   * Square brackets are stripped from the definiens and the comparison is case-insensitive.
   */
  public boolean matchesGold(String identifier, String definiens, GoldEntry gold) {
    List<IdentifierDefinition> identifierDefinitions = gold.getDefinitions();
    return identifierDefinitions.contains(
        new IdentifierDefinition(identifier, definiens.replaceAll("\\[|\\]", "").trim().toLowerCase()));
  }

  /**
   * Maps every identifier to the index of the first sentence in which it occurs.
   */
  private Map<String, Integer> findSentencesWithIdentifierFirstOccurrences(List<Sentence> sentences, Collection<String> identifiers) {
    Map<String, Integer> result = new HashMap<>();
    for (String identifier : identifiers) {
      for (int i = 0; i < sentences.size(); i++) {
        Sentence sentence = sentences.get(i);
        if (sentence.contains(identifier)) {
          result.put(identifier, i);
          break;
        }
      }
    }
    return result;
  }

  /**
   * Aggregates the words to make counting easy. Words shorter than three characters are ignored.
   *
   * @param sentences sentences from which to aggregate the words.
   * @return Multiset with an entry for every word.
   */
  private Multiset<String> aggregateWords(List<Sentence> sentences) {
    Multiset<String> counts = HashMultiset.create();
    for (Sentence sentence : sentences) {
      for (Word word : sentence.getWords()) {
        if (word.getWord().length() >= 3) {
          counts.add(word.getWord().toLowerCase());
        }
      }
    }
    return counts;
  }

  /**
   * Returns the count of the most frequent word in the multiset.
   */
  private int calculateMax(Multiset<String> frequencies) {
    Multiset.Entry<String> max = Collections.max(frequencies.entrySet(),
        Comparator.comparingInt(Multiset.Entry::getCount));
    return max.getCount();
  }
}
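
// Usage sketch (illustrative, not part of the original source): since this class
// implements Flink's MapFunction, it is typically applied to a DataSet of parsed
// documents. The variable names below (gold, config, parsedDocuments, doc) are
// assumptions for the example, not names from this repository.
//
//   List<GoldEntry> gold = /* loaded from the gold-standard data */;
//   MachineLearningDefinienExtractionConfig config = /* CLI configuration */;
//   DataSet<WikiDocumentOutput> features =
//       parsedDocuments.map(new SimpleFeatureExtractor(config, gold));
//
// Outside of a Flink job, the extractor can also be invoked directly on a single
// document: new SimpleFeatureExtractor(config, gold).map(doc).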