package com.formulasearchengine.mathosphere.mlp.text; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Multiset; import com.google.common.collect.Sets; import com.alexeygrigorev.rseq.BeanMatchers; import com.alexeygrigorev.rseq.Match; import com.alexeygrigorev.rseq.Matchers; import com.alexeygrigorev.rseq.Pattern; import com.alexeygrigorev.rseq.TransformerToElement; import com.alexeygrigorev.rseq.XMatcher; import com.formulasearchengine.mathosphere.mlp.cli.BaseConfig; import com.formulasearchengine.mathosphere.mlp.pojos.MathTag; import com.formulasearchengine.mathosphere.mlp.pojos.Sentence; import com.formulasearchengine.mathosphere.mlp.pojos.Word; import com.formulasearchengine.mathosphere.mlp.rus.RusPosAnnotator; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.stream.Collectors; import edu.stanford.nlp.ling.CoreAnnotations.PartOfSpeechAnnotation; import edu.stanford.nlp.ling.CoreAnnotations.SentencesAnnotation; import edu.stanford.nlp.ling.CoreAnnotations.TextAnnotation; import edu.stanford.nlp.ling.CoreAnnotations.TokensAnnotation; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.pipeline.Annotation; import edu.stanford.nlp.pipeline.POSTaggerAnnotator; import edu.stanford.nlp.pipeline.StanfordCoreNLP; import edu.stanford.nlp.util.CoreMap; public class PosTagger { private static BaseConfig config; private static final Logger LOGGER = LoggerFactory.getLogger(PosTagger.class); private static final Set<String> SYMBOLS = ImmutableSet.of("<", "=", ">", "≥", "≤", "|", "/", "\\", "[", "]", "*"); private static final Map<String, String> BRACKET_CODES = ImmutableMap.<String, String>builder() .put("-LRB-", "(").put("-RRB-", ")").put("-LCB-", "{").put("-RCB-", "}").put("-LSB-", "[") .put("-RSB-", "]").build(); public static PosTagger create(BaseConfig cfg) { config = cfg; Properties props = new Properties(); props.put("annotators", "tokenize, ssplit"); props.put("tokenize.options", "untokenizable=firstKeep,strictTreebank3=true," + "ptb3Escaping=true,escapeForwardSlashAsterisk=false"); props.put("ssplit.newlineIsSentenceBreak", "two"); props.put("maxLength", 50); StanfordCoreNLP pipeline = new StanfordCoreNLP(props); if ("en".equals(cfg.getLanguage())) { POSTaggerAnnotator modelBasedPosAnnotator = new POSTaggerAnnotator(config.getModel(), false); pipeline.addAnnotator(modelBasedPosAnnotator); } else if ("ru".equals(cfg.getLanguage())) { pipeline.addAnnotator(new RusPosAnnotator()); } else { throw new IllegalArgumentException("Cannot deal with language " + config.getLanguage()); } return new PosTagger(pipeline); } private final StanfordCoreNLP nlpPipeline; public PosTagger(StanfordCoreNLP nlpPipeline) { this.nlpPipeline = nlpPipeline; } public List<Sentence> process(String cleanText, List<MathTag> formulas) { Map<String, MathTag> formulaIndex = Maps.newHashMap(); Set<String> allIdentifiers = Sets.newHashSet(); formulas.forEach(f -> formulaIndex.put(f.getKey(), f)); //wrap all single character identifiers in \\mathit{} tag formulas.forEach(f -> allIdentifiers.addAll( f.getIdentifiers(config) .stream() .map(e -> e.matches(".") ? "\\mathit{" + e + "}" : e) .collect(Collectors.toList()) )); List<List<Word>> annotated = annotate(cleanText, formulaIndex, allIdentifiers); List<List<Word>> concatenated = concatenateTags(annotated, allIdentifiers); return postprocess(concatenated, formulaIndex, allIdentifiers); } public List<List<Word>> annotate(String cleanText, Map<String, MathTag> formulas, Set<String> allIdentifiers) { Annotation document = new Annotation(cleanText); nlpPipeline.annotate(document); List<List<Word>> result = Lists.newArrayList(); for (CoreMap sentence : document.get(SentencesAnnotation.class)) { List<Word> words = Lists.newArrayList(); final List<CoreLabel> coreLabels = sentence.get(TokensAnnotation.class); for (CoreLabel token : coreLabels) { String textToken = token.get(TextAnnotation.class); String pos = token.get(PartOfSpeechAnnotation.class); if (textToken.startsWith("FORMULA_")) { words.add(new Word(textToken, PosTag.MATH)); } else if (SYMBOLS.contains(textToken)) { words.add(new Word(textToken, PosTag.SYMBOL)); } else if (BRACKET_CODES.containsKey(textToken)) { words.add(new Word(BRACKET_CODES.get(textToken), pos)); } else if (textToken.startsWith("LINK_")) { words.add(new Word(textToken, PosTag.LINK)); } else { words.add(new Word(textToken, pos)); } } result.add(words); } return result; } public static List<Sentence> postprocess(List<List<Word>> input, Map<String, MathTag> formulaIndex, Set<String> allIdentifiers) { List<Sentence> result = Lists.newArrayListWithCapacity(input.size()); for (List<Word> words : input) { Sentence sentence = toSentence(words, formulaIndex, allIdentifiers); result.add(sentence); } return result; } public static Sentence toSentence(List<Word> input, Map<String, MathTag> formulaIndex, Set<String> allIdentifiers) { List<Word> words = Lists.newArrayListWithCapacity(input.size()); Set<String> sentenceIdentifiers = Sets.newHashSet(); List<MathTag> formulas = Lists.newArrayList(); for (Word w : input) { String word = w.getWord(); String pos = w.getPosTag(); if (allIdentifiers.contains(word) && !PosTag.IDENTIFIER.equals(pos)) { words.add(new Word(word, PosTag.IDENTIFIER)); sentenceIdentifiers.add(word); continue; } if (PosTag.MATH.equals(pos)) { String formulaKey = word; if (word.length() > 40) { formulaKey = word.substring(0, 40); } MathTag formula = formulaIndex.get(formulaKey); if (formula == null) { LOGGER.warn("formula {} does not exist", word); words.add(w); continue; } formulas.add(formula); Multiset<String> formulaIdentifiers = formula.getIdentifiers(config); // only one occurrence of one single idendifier if (formulaIdentifiers.size() == 1) { String id = Iterables.get(formulaIdentifiers, 0); LOGGER.debug("convering formula {} to idenfier {}", formula.getKey(), id); words.add(new Word(id, PosTag.IDENTIFIER)); sentenceIdentifiers.add(id); } else { words.add(w); } if (word.length() > 40) { String rest = word.substring(40, word.length()); words.add(new Word(rest, PosTag.SUFFIX)); } continue; } words.add(w); } return new Sentence(words, sentenceIdentifiers, formulas); } public static List<List<Word>> concatenateTags(List<List<Word>> sentences, Set<String> allIdentifiers) { List<List<Word>> results = Lists.newArrayListWithCapacity(sentences.size()); for (List<Word> sentence : sentences) { List<Word> res = postprocessSentence(sentence, allIdentifiers); results.add(res); } return results; } private static List<Word> postprocessSentence(List<Word> sentence, Set<String> allIdentifiers) { // links List<Word> result; if (config.getUseTeXIdentifiers()) { result = sentence; } else { result = concatenateLinks(sentence, allIdentifiers); } // noun phrases result = concatenateSuccessiveNounsToNounSequence(result); result = contatenateSuccessive2Tags(result, PosTag.ADJECTIVE, PosTag.NOUN, PosTag.NOUN_PHRASE); result = contatenateSuccessive2Tags(result, PosTag.ADJECTIVE, PosTag.NOUN_PLURAL, PosTag.NOUN_PHRASE); result = contatenateSuccessive2Tags(result, PosTag.ADJECTIVE, PosTag.NOUN_SEQUENCE, PosTag.NOUN_SEQUENCE_PHRASE); return result; } public static List<Word> concatenateLinks(List<Word> in, Set<String> allIdentifiers) { Pattern<Word> linksPattern = Pattern.create(pos(PosTag.QUOTE), anyWord().oneOrMore() .captureAs("link"), pos(PosTag.UNQUOTE)); return linksPattern.replaceToOne(in, new TransformerToElement<Word>() { @Override public Word transform(Match<Word> match) { List<Word> words = match.getCapturedGroup("link"); if (words.size() == 1 && allIdentifiers.contains("\\mathit{" + words.get(0).getWord() + "}")) { return new Word(joinWords(words), PosTag.IDENTIFIER); } else { return new Word(joinWords(words), PosTag.LINK); } } }); } public static List<Word> concatenateSuccessiveNounsToNounSequence(List<Word> in) { XMatcher<Word> noun = posIn(PosTag.NOUN, PosTag.NOUN_PLURAL); Pattern<Word> nounPattern = Pattern.create(noun.oneOrMore()); return nounPattern.replaceToOne(in, new TransformerToElement<Word>() { @Override public Word transform(Match<Word> match) { List<Word> words = match.getMatchedSubsequence(); if (words.size() == 1) { return words.get(0); } return new Word(joinWords(words), PosTag.NOUN_SEQUENCE); } }); } public static List<Word> contatenateSuccessive2Tags(List<Word> in, String tag1, String tag2, String outputTag) { Pattern<Word> pattern = Pattern.create(pos(tag1), pos(tag2)); return pattern.replaceToOne(in, m -> new Word(joinWords(m.getMatchedSubsequence()), outputTag)); } public static String joinWords(List<Word> list) { List<String> toJoin = Lists.newArrayList(); list.forEach(w -> toJoin.add(w.getWord())); return StringUtils.join(toJoin, " "); } public static XMatcher<Word> pos(String tag) { return BeanMatchers.eq(Word.class, "posTag", tag); } public static XMatcher<Word> posIn(String... tags) { return BeanMatchers.in(Word.class, "posTag", ImmutableSet.copyOf(tags)); } public static XMatcher<Word> anyWord() { return Matchers.anything(); } }