package com.formulasearchengine.mathosphere.mlp;

import com.formulasearchengine.mathosphere.mlp.cli.MachineLearningDefinienClassifierConfig;
import com.formulasearchengine.mathosphere.mlp.ml.WekaClassifier;
import com.formulasearchengine.mathosphere.mlp.pojos.*;
import com.formulasearchengine.mathosphere.mlp.text.SimpleFeatureExtractorMapper;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.formulasearchengine.mathosphere.mlp.cli.EvalCommandConfig;
import com.formulasearchengine.mathosphere.mlp.cli.FlinkMlpCommandConfig;
import com.formulasearchengine.mathosphere.mlp.contracts.*;
import com.formulasearchengine.mathosphere.mlp.text.WikiTextUtils;
import com.formulasearchengine.mathosphere.utils.Util;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.TextInputFormat;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.FileSystem.WriteMode;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.*;
import java.util.stream.Collectors;

public class FlinkMlpRelationFinder {

  private static final Logger LOGGER = LoggerFactory.getLogger(FlinkMlpRelationFinder.class);

  public static void main(String[] args) throws Exception {
    FlinkMlpCommandConfig config = FlinkMlpCommandConfig.from(args);
    run(config);
  }

  public static void run(FlinkMlpCommandConfig config) throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    DataSource<String> source = readWikiDump(config, env);
    DataSet<ParsedWikiDocument> documents = source.flatMap(new TextExtractorMapper())
        .map(new TextAnnotatorMapper(config));
    DataSet<WikiDocumentOutput> result = documents.map(new CreateCandidatesMapper(config));
    result.map(new JsonSerializerMapper<>())
        .writeAsText(config.getOutputDir(), WriteMode.OVERWRITE);
    final int parallelism = config.getParallelism();
    if (parallelism > 0) {
      env.setParallelism(parallelism);
    }
    env.execute("Relation Finder");
  }

  public static DataSource<String> readWikiDump(FlinkMlpCommandConfig config, ExecutionEnvironment env) {
    Path filePath = new Path(config.getDataset());
    TextInputFormat inp = new TextInputFormat(filePath);
    inp.setCharsetName("UTF-8");
    // Split the XML dump at page boundaries so that each record holds one <page> element.
    inp.setDelimiter("</page>");
    return env.readFile(inp, config.getDataset());
  }

  public static void evaluate(EvalCommandConfig config) throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    DataSource<String> source = readWikiDump(config, env);
    final MapFunction<ParsedWikiDocument, WikiDocumentOutput> candidatesMapper;
    if (config.isPatternMatcher()) {
      candidatesMapper = new PatternMatcherMapper();
    } else {
      candidatesMapper = new CreateCandidatesMapper(config);
    }
    DataSet<ParsedWikiDocument> documents = source.flatMap(new TextExtractorMapper())
        .map(new TextAnnotatorMapper(config));
    final File file = new File(config.getQueries());
    ObjectMapper mapper = new ObjectMapper();
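    // The queries file is a JSON list; judging from the accesses below, each entry
    // looks roughly like the following (shape inferred from usage, not a documented schema):
    //   { "formula": { "title": ..., "fid": ..., "math_inputtex": ...,
    //                  "qID": ..., "oldId": ... },
    //     "definitions": { "<identifier>": [ <definiens>, ... ], ... } }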
    List userData = mapper.readValue(file, List.class);
    // Index the gold-standard entries by formula title.
    Map<String, Object> gold = new HashMap<>();
    for (Object o : userData) {
      final Map entry = (Map) o;
      Map formula = (Map) entry.get("formula");
      gold.put((String) formula.get("title"), o);
    }
    final File ndFile = new File(config.getNdFile());
    List ndList = mapper.readValue(ndFile, List.class);
    // Index the namespace data by document title, normalized to underscores.
    Map<String, Object> ndData = new HashMap<>();
    for (Object o : ndList) {
      final Map entry = (Map) o;
      ndData.put(((String) entry.get("document_title")).replaceAll(" ", "_"), o);
    }
    GroupReduceFunction<ParsedWikiDocument, String> reduceFunction = new GroupReduceFunction<ParsedWikiDocument, String>() {
      @Override
      public void reduce(Iterable<ParsedWikiDocument> iterable, Collector<String> collector) throws Exception {
        Multiset<String> tpOverall = HashMultiset.create();
        Multiset<String> fnOverall = HashMultiset.create();
        Multiset<String> fpOverall = HashMultiset.create();
        Multiset<Relation> tpRelOverall = HashMultiset.create();
        int fnRelOverallCnt = 0;
        Multiset<Relation> fpRelOverall = HashMultiset.create();
        for (ParsedWikiDocument parsedWikiDocument : iterable) {
          String title = parsedWikiDocument.getTitle().replaceAll(" ", "_");
          try {
            Map goldElement = (Map) gold.get(title);
            Map formula = (Map) goldElement.get("formula");
            final Integer formulaId = Integer.parseInt((String) formula.get("fid"));
            final String tex = (String) formula.get("math_inputtex");
            final String qId = (String) formula.get("qID");
            // The seed formula is addressed by its position among the LaTeX formulas.
            final MathTag seed = parsedWikiDocument.getFormulas().stream()
                .filter(f -> f.getMarkUpType().equals(WikiTextUtils.MathMarkUpType.LATEX))
                .collect(Collectors.toList())
                .get(formulaId);
            if (!seed.getContent().equals(tex)) {
              LOGGER.error("Problem with " + title);
              LOGGER.error(seed.getContent());
              LOGGER.error(tex);
              throw new Exception("Invalid numbering.");
            }
            final WikiDocumentOutput wikiDocumentOutput = candidatesMapper.map(parsedWikiDocument);
            List<Relation> relations = wikiDocumentOutput.getRelations();
            final Set<String> real = seed.getIdentifiers(config).elementSet();
            // Only keep identifiers that have a definition in the gold standard.
            final Map definitions = (Map) goldElement.get("definitions");
            final Set expected = definitions.keySet();
            // Identifier-level scoring via set arithmetic:
            // tp = expected ∩ real, fn = expected \ real, fp = real \ expected.
            Set<String> tp = new HashSet<>(expected);
            Set<String> fn = new HashSet<>(expected);
            Set<String> fp = new HashSet<>(real);
            tp.retainAll(real);
            fn.removeAll(real);
            fp.removeAll(expected);
            tpOverall.addAll(tp);
            fnOverall.addAll(fn);
            fpOverall.addAll(fp);
            LOGGER.info("https://en.formulasearchengine.com/wiki/" + title
                + "#math." + formula.get("oldId") + "." + formulaId);
            if (config.getNamespace()) {
              getNamespaceData(title, relations);
            }
            // Remove identifiers that are not in the gold standard; these were
            // errors of the identifier extraction.
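            // What follows: drop relations for unexpected identifiers, sort so that
            // duplicates become adjacent, deduplicate, write the q<qId>.csv relevance
            // template, and score the survivors against existing relevance judgments.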
            relations.removeIf(r -> !expected.contains(r.getIdentifier()));
            Collections.sort(relations, Relation::compareToName);
            removeDuplicates(definitions, relations);
            writeRelevanceTemplates(qId, relations);
            Util.writeExtractedDefinitionsAsCsv(config.getOutputDir() + "/extraction.csv",
                qId, wikiDocumentOutput.getTitle().replaceAll("\\s", "_"), relations);
            Map<Tuple2<String, String>, Integer> references = new HashMap<>();
            if (config.getRelevanceFolder() != null) {
              // Each relevance file q<qId>.csv holds one judged relation per row:
              // identifier, definition, relevance ranking.
              final FileReader relevance = new FileReader(config.getRelevanceFolder() + "/q" + qId + ".csv");
              Iterable<CSVRecord> records = CSVFormat.RFC4180.parse(relevance);
              for (CSVRecord record : records) {
                String identifier = record.get(0);
                if (identifier.length() > 0) {
                  String definition = record.get(1);
                  Integer relevanceRanking = Integer.valueOf(record.get(2));
                  references.put(new Tuple2<>(identifier, definition), relevanceRanking);
                }
              }
              int tpcnt = 0;
              for (Relation relation : relations) {
                Integer score = references.get(new Tuple2<>(relation.getIdentifier(), relation.getDefinition()));
                if (score != null && score >= config.getLevel()) {
                  tpRelOverall.add(relation);
                  LOGGER.info("tp: " + relation.getIdentifier() + ", " + relation.getDefinition());
                  tpcnt++;
                } else {
                  fpRelOverall.add(relation);
                  LOGGER.info("fp: " + relation.getIdentifier() + ", " + relation.getDefinition());
                }
              }
              fnRelOverallCnt += (expected.size() - tpcnt);
            }
          } catch (Exception e) {
            e.printStackTrace();
            LOGGER.info("Problem with " + title);
          }
        }
        LOGGER.info("Overall identifier evaluation");
        LOGGER.info("fp:" + fpOverall.size());
        LOGGER.info("fn:" + fnOverall.size());
        LOGGER.info("tp:" + tpOverall.size());
        LOGGER.info("Overall definition evaluation by this method; better use the evaluation in the Evaluation package.");
        LOGGER.info("fp=" + fpRelOverall.size() + "; fn=" + fnRelOverallCnt + "; tp=" + tpRelOverall.size());
        LOGGER.info(fpRelOverall.toString());
      }

      public void removeDuplicates(Map definitions, List<Relation> relations) {
        String lastDef = "";
        String lastIdent = "";
        final Iterator<Relation> iterator = relations.iterator();
        while (iterator.hasNext()) {
          final Relation relation = iterator.next();
          final List<String> refList = getDefiniens(definitions, relation);
          final String definition = relation.getDefinition()
              .replaceAll("(\\[\\[|\\]\\])", "").replaceAll("_", " ").trim().toLowerCase();
          // Mark relations whose definition matches a gold definiens as highly relevant.
          if (refList.contains(definition)) {
            relation.setRelevance(2);
          }
          // The list is sorted, so duplicates are adjacent: drop a relation whose
          // identifier and (case-insensitive) definition equal the previous one.
          if (lastIdent.equals(relation.getIdentifier())
              && relation.getDefinition().equalsIgnoreCase(lastDef)) {
            iterator.remove();
          }
          lastDef = relation.getDefinition();
          lastIdent = relation.getIdentifier();
        }
      }

      public void writeRelevanceTemplates(String qId, List<Relation> relations) throws IOException {
        if (config.getOutputDir() != null) {
          final File output = new File(config.getOutputDir() + "/q" + qId + ".csv");
          output.createNewFile();
          OutputStreamWriter w = new FileWriter(output);
          CSVPrinter printer = CSVFormat.DEFAULT.withRecordSeparator("\n").print(w);
          for (Relation relation : relations) {
            String sScore;
            if (relation.getRelevance() == null) {
              sScore = "";
            } else {
              sScore = String.valueOf(relation.getRelevance());
            }
            String[] out = new String[]{relation.getIdentifier(), relation.getDefinition(), sScore};
            printer.printRecord(out);
          }
          w.flush();
          w.close();
        }
      }

      public void getNamespaceData(String title, List<Relation> relations) {
        final Map nd = (Map) ndData.get(title);
        if (nd != null) {
          List relNS = (List) nd.get("namespace_relations");
          if (relNS != null) {
            for (Object o : relNS) {
              Relation rel = new Relation(o);
              relations.add(rel);
            }
          }
        }
      }
    };
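    // Run as a single group reduce so that one reduce call sees every document
    // and the overall tp/fp/fn counters aggregate across the whole corpus.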
    env.setParallelism(1);
    documents.reduceGroup(reduceFunction).print();
  }

  public static List<String> getDefiniens(Map definitions, Relation relation) {
    List<String> result = new ArrayList<>();
    List definiens = (List) definitions.get(relation.getIdentifier());
    for (Object definien : definiens) {
      if (definien instanceof Map) {
        Map<String, String> var = (Map) definien;
        for (Map.Entry<String, String> stringStringEntry : var.entrySet()) {
          // there is only one entry
          final String def = stringStringEntry.getValue().trim()
              .replaceAll("\\s*\\(.*?\\)$", "").toLowerCase();
          result.add(def);
        }
      } else {
        result.add(((String) definien).toLowerCase());
      }
    }
    return result;
  }

  public String runFromText(FlinkMlpCommandConfig config, String input) throws Exception {
    final JsonSerializerMapper<Object> serializerMapper = new JsonSerializerMapper<>();
    return serializerMapper.map(outDocFromText(config, input));
  }

  public WikiDocumentOutput outDocFromText(FlinkMlpCommandConfig config, String input) throws Exception {
    final TextAnnotatorMapper textAnnotatorMapper = new TextAnnotatorMapper(config);
    textAnnotatorMapper.open(null);
    final CreateCandidatesMapper candidatesMapper = new CreateCandidatesMapper(config);
    final ParsedWikiDocument parsedWikiDocument = textAnnotatorMapper.parse(input);
    return candidatesMapper.map(parsedWikiDocument);
  }
}
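// Usage sketch for the single-document API (the wikitext sample is hypothetical;
// FlinkMlpCommandConfig.from(args) parses the CLI flags exactly as in main above):
//
//   FlinkMlpCommandConfig cfg = FlinkMlpCommandConfig.from(args);
//   FlinkMlpRelationFinder finder = new FlinkMlpRelationFinder();
//   String json = finder.runFromText(cfg, "''E'' denotes the energy and ''m'' the mass.");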