package com.formulasearchengine.mathosphere.mathpd;
import com.formulasearchengine.mathosphere.mathpd.cli.FlinkPdCommandConfig;
import com.formulasearchengine.mathosphere.mathpd.contracts.PreprocessedExtractedMathPDDocumentMapper;
import com.formulasearchengine.mathosphere.mathpd.contracts.TextExtractorMapper;
import com.formulasearchengine.mathosphere.mathpd.pojos.ExtractedMathPDDocument;
import com.formulasearchengine.mathosphere.mlp.contracts.CreateCandidatesMapper;
import com.formulasearchengine.mathosphere.mlp.contracts.JsonSerializerMapper;
import com.formulasearchengine.mathosphere.mlp.contracts.TextAnnotatorMapper;
import com.formulasearchengine.mathosphere.mlp.pojos.ParsedWikiDocument;
import com.formulasearchengine.mathosphere.mlp.pojos.WikiDocumentOutput;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.TextInputFormat;
import org.apache.flink.api.java.io.TextOutputFormat;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.tuple.*;
import org.apache.flink.core.fs.FileSystem.WriteMode;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.*;
public class FlinkPd {
private static final Logger LOGGER = LoggerFactory.getLogger(FlinkPd.class);
private static final int NUMBER_OF_PARTITIONS = -1; // -1 disables partitioning: all snippets of a document are merged into one document
public static boolean IS_MODE_PREPROCESSING = true;
private static final DecimalFormat decimalFormat = new DecimalFormat("0.0", DecimalFormatSymbols.getInstance(Locale.ROOT)); // fixed symbols so Double.valueOf can always parse the formatted value
public static void main(String[] args) throws Exception {
FlinkPdCommandConfig config = FlinkPdCommandConfig.from(args);
run(config);
}
/**
 * Takes math PD snippets and merges them into overlapping partitions, each of which
 * becomes one pseudo-document (all snippets of a partition merged into one document).
 *
 * @param extractedMathPdSnippets the (name, snippet) tuples extracted from the input
 * @return one merged (name, document) tuple per overlapping partition
 */
private static DataSet<Tuple2<String, ExtractedMathPDDocument>> aggregateSnippetsToPartitions(FlatMapOperator<String, Tuple2<String, ExtractedMathPDDocument>> extractedMathPdSnippets) {
DataSet<Tuple2<String, ExtractedMathPDDocument>> extractedMathPdDocuments = extractedMathPdSnippets
.groupBy(0)
.reduceGroup(new GroupReduceFunction<Tuple2<String, ExtractedMathPDDocument>, Tuple2<String, ExtractedMathPDDocument>>() {
@Override
public void reduce(Iterable<Tuple2<String, ExtractedMathPDDocument>> iterable, Collector<Tuple2<String, ExtractedMathPDDocument>> collector) throws Exception {
final List<Tuple2<String, ExtractedMathPDDocument>> sortedNamesAndSnippets = new ArrayList<>();
for (Tuple2<String, ExtractedMathPDDocument> nameAndSnippet : iterable) {
sortedNamesAndSnippets.add(nameAndSnippet);
}
//LOGGER.warn("sorting {} entries", sortedNamesAndSnippets.size());
Collections.sort(sortedNamesAndSnippets, new Comparator<Tuple2<String, ExtractedMathPDDocument>>() {
@Override
public int compare(Tuple2<String, ExtractedMathPDDocument> o1, Tuple2<String, ExtractedMathPDDocument> o2) {
return o1.f1.getPage().compareTo(o2.f1.getPage());
}
});
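// Split the page-sorted snippets into NUMBER_OF_PARTITIONS roughly equal chunks and then
// let neighboring chunks overlap by 25% on each side (assuming CollectionUtils.partition
// and CollectionUtils.overlapInPercent behave as their names suggest). E.g., with 8
// snippets and 2 partitions, [s0..s3] and [s4..s7] become [s0..s4] and [s3..s7], so
// content near a partition boundary is not lost to either pseudo-document.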
final List<List<Tuple2<String, ExtractedMathPDDocument>>> partitions = CollectionUtils.partition(sortedNamesAndSnippets, NUMBER_OF_PARTITIONS);
List<Tuple3<String, String, String>> partitionFirstEntryTitles = new ArrayList<>();
for (List<Tuple2<String, ExtractedMathPDDocument>> partition : partitions) {
partitionFirstEntryTitles.add(new Tuple3<>(
partition.get(0).f1.getTitle(),
partition.get(0).f1.getName(),
partition.get(0).f1.getPage()));
}
final List<List<Tuple2<String, ExtractedMathPDDocument>>> overlappingPartitions = CollectionUtils.overlapInPercent(partitions, 0.25, 0.25);
//LOGGER.warn("overlappingPartitions number = {}", overlappingPartitions.size());
int i = 0;
for (List<Tuple2<String, ExtractedMathPDDocument>> overlappingPartition : overlappingPartitions) {
//LOGGER.warn("merging partition with {} entries", overlappingPartition.size());
final Tuple2<String, ExtractedMathPDDocument> mergedPartition = mergeToOne(overlappingPartition);
// overwrite these properties to avoid duplicates later; the duplicates were introduced
// when creating the overlapping partitions.
Tuple3<String, String, String> firstOriginalEntry = partitionFirstEntryTitles.get(i++);
mergedPartition.setField(firstOriginalEntry.f0, 0);
mergedPartition.f1.setTitle(firstOriginalEntry.f0);
mergedPartition.f1.setName(firstOriginalEntry.f1);
mergedPartition.f1.setPage(firstOriginalEntry.f2);
collector.collect(mergedPartition);
//LOGGER.warn(mergedPartition.f0);
}
}
});
return extractedMathPdDocuments;
}
private static Tuple2<String, ExtractedMathPDDocument> mergeToOne(List<Tuple2<String, ExtractedMathPDDocument>> list) {
final List<HashMap<String, Double>> allHistogramsCi = new ArrayList<>();
final List<HashMap<String, Double>> allHistogramsCn = new ArrayList<>();
final List<HashMap<String, Double>> allHistogramsCsymbol = new ArrayList<>();
final List<HashMap<String, Double>> allHistogramsBvar = new ArrayList<>();
String mainString = null;
ExtractedMathPDDocument mainDoc = null;
for (Tuple2<String, ExtractedMathPDDocument> nameAndSnippet : list) {
final String name = nameAndSnippet.f0;
final ExtractedMathPDDocument snippet = nameAndSnippet.f1;
if (mainDoc == null) {
mainDoc = snippet;
mainString = name;
}
allHistogramsCi.add(snippet.getHistogramCi());
allHistogramsCn.add(snippet.getHistogramCn());
allHistogramsCsymbol.add(snippet.getHistogramCsymbol());
allHistogramsBvar.add(snippet.getHistogramBvar());
}
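// Distances.histogramsPlus presumably sums the histograms key-wise, e.g.
// {a: 1.0, b: 2.0} + {a: 3.0} -> {a: 4.0, b: 2.0} (assumption based on the name),
// so the merged document's histograms cover all of its snippets.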
mainDoc.setHistogramCi(Distances.histogramsPlus(allHistogramsCi));
mainDoc.setHistogramCn(Distances.histogramsPlus(allHistogramsCn));
mainDoc.setHistogramCsymbol(Distances.histogramsPlus(allHistogramsCsymbol));
mainDoc.setHistogramBvar(Distances.histogramsPlus(allHistogramsBvar));
return new Tuple2<>(mainString, mainDoc);
}
private static DataSet<Tuple2<String, ExtractedMathPDDocument>> aggregateSnippets(FlatMapOperator<String, Tuple2<String, ExtractedMathPDDocument>> extractedMathPdSnippets) {
if (NUMBER_OF_PARTITIONS > 0) {
    return aggregateSnippetsToPartitions(extractedMathPdSnippets);
} else if (NUMBER_OF_PARTITIONS == -1) {
    return aggregateSnippetsToSingleDocs(extractedMathPdSnippets);
} else {
    throw new IllegalStateException("illegal NUMBER_OF_PARTITIONS: " + NUMBER_OF_PARTITIONS);
}
}
/**
 * Takes math PD snippets and merges all snippets belonging to the same document into a
 * single document.
 *
 * @param extractedMathPdSnippets the (name, snippet) tuples extracted from the input
 * @return one merged (name, document) tuple per distinct input document
 */
private static DataSet<Tuple2<String, ExtractedMathPDDocument>> aggregateSnippetsToSingleDocs(FlatMapOperator<String, Tuple2<String, ExtractedMathPDDocument>> extractedMathPdSnippets) {
DataSet<Tuple2<String, ExtractedMathPDDocument>> extractedMathPdDocuments = extractedMathPdSnippets
.groupBy(0)
.reduce(new ReduceFunction<Tuple2<String, ExtractedMathPDDocument>>() {
@Override
public Tuple2<String, ExtractedMathPDDocument> reduce(Tuple2<String, ExtractedMathPDDocument> t0, Tuple2<String, ExtractedMathPDDocument> t1) throws Exception {
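// Snippets grouped under the same key are folded pairwise into one document. The raw
// text is dropped after merging, presumably to keep the merged documents small since
// only the histograms are needed downstream (assumption).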
t1.f1.mergeOtherIntoThis(t0.f1);
t1.f1.setText("removed");
//LOGGER.info("merged {} into {}", new Object[]{t1.f0, t0.f0});
return t1;
}
});
return extractedMathPdDocuments;
}
public static void run(FlinkPdCommandConfig config) throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
final String preprocessedSourcesFiles = config.getDataset() + "_preprocessed";
String preprocessedRefsFiles = config.getRef() + "_preprocessed";
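// if dataset and ref point to the same file, write the refs' output to a separate path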
if (preprocessedRefsFiles.equals(preprocessedSourcesFiles)) {
preprocessedRefsFiles += "2";
}
if (IS_MODE_PREPROCESSING) {
DataSource<String> source = readWikiDump(config, env);
DataSource<String> refs = readRefs(config, env);
final FlatMapOperator<String, Tuple2<String, ExtractedMathPDDocument>> extractedMathPdSnippetsSources = source.flatMap(new TextExtractorMapper(true));
// first, merge all pages of one doc to one doc
DataSet<Tuple2<String, ExtractedMathPDDocument>> extractedMathPdDocumentsSources = aggregateSnippets(extractedMathPdSnippetsSources);
// write to disk
LOGGER.info("writing preprocesssed input to disk at {}", preprocessedRefsFiles);
extractedMathPdDocumentsSources.writeAsFormattedText(preprocessedSourcesFiles,
new TextOutputFormat.TextFormatter<Tuple2<String, ExtractedMathPDDocument>>() {
@Override
public String format(Tuple2<String, ExtractedMathPDDocument> nameAndDoc) {
    return PreprocessedExtractedMathPDDocumentMapper.getFormattedWritableText(nameAndDoc.f1);
}
});
// now for the refs
final FlatMapOperator<String, Tuple2<String, ExtractedMathPDDocument>> extractedMathPdSnippetsRefs = refs.flatMap(new TextExtractorMapper(false));
// first, merge all pages of one doc to one doc
final DataSet<Tuple2<String, ExtractedMathPDDocument>> extractedMathPdDocumentsRefs = aggregateSnippets(extractedMathPdSnippetsRefs);
// write to disk
LOGGER.info("writing preprocesssed refs to disk at {}", preprocessedRefsFiles);
extractedMathPdDocumentsRefs.writeAsFormattedText(preprocessedRefsFiles,
new TextOutputFormat.TextFormatter<Tuple2<String, ExtractedMathPDDocument>>() {
@Override
public String format(Tuple2<String, ExtractedMathPDDocument> nameAndDoc) {
    return PreprocessedExtractedMathPDDocumentMapper.getFormattedWritableText(nameAndDoc.f1);
}
});
} else {
final DataSet<Tuple2<String, ExtractedMathPDDocument>> extractedMathPdDocumentsSources = readPreprocessedFile(preprocessedSourcesFiles, env).flatMap(new PreprocessedExtractedMathPDDocumentMapper());
final DataSet<Tuple2<String, ExtractedMathPDDocument>> extractedMathPdDocumentsRefs = readPreprocessedFile(preprocessedRefsFiles, env).flatMap(new PreprocessedExtractedMathPDDocumentMapper());
DataSet<Tuple7<String, String, Double, Double, Double, Double, Double>> distancesAndSectionPairs =
extractedMathPdDocumentsSources
.map(new MapFunction<Tuple2<String, ExtractedMathPDDocument>, ExtractedMathPDDocument>() {
@Override
public ExtractedMathPDDocument map(Tuple2<String, ExtractedMathPDDocument> nameAndDoc) throws Exception {
    return nameAndDoc.f1;
}
})
.cross(extractedMathPdDocumentsRefs
.map(new MapFunction<Tuple2<String, ExtractedMathPDDocument>, ExtractedMathPDDocument>() {
@Override
public ExtractedMathPDDocument map(Tuple2<String, ExtractedMathPDDocument> nameAndDoc) throws Exception {
    return nameAndDoc.f1;
}
})
)
.map(new MapFunction<Tuple2<ExtractedMathPDDocument, ExtractedMathPDDocument>, Tuple7<String, String, Double, Double, Double, Double, Double>>() {
@Override
public Tuple7<String, String, Double, Double, Double, Double, Double> map(Tuple2<ExtractedMathPDDocument, ExtractedMathPDDocument> docPair) throws Exception {
    if (docPair.f0 == null || docPair.f1 == null) {
        return null;
    }
    // distanceAllFeatures holds the four per-feature distances:
    // numbers, operators, identifiers, and bound variables.
    // The total distance (accumulated over all features) is computed below from their
    // absolute values. Note that if cosine is used, the term "distance" actually means
    // similarity (-1 = opposite, 0 = unrelated, 1 = same doc), so the accumulated total
    // makes no sense in that case.
    final Tuple4<Double, Double, Double, Double> distanceAllFeatures =
            Distances.distanceRelativeAllFeatures(docPair.f0, docPair.f1);
    final Tuple7<String, String, Double, Double, Double, Double, Double> resultLine = new Tuple7<>(
            docPair.f0.getId(),
            docPair.f1.getId(),
            Math.abs(distanceAllFeatures.f0) + Math.abs(distanceAllFeatures.f1) + Math.abs(distanceAllFeatures.f2) + Math.abs(distanceAllFeatures.f3),
            distanceAllFeatures.f0,
            distanceAllFeatures.f1,
            distanceAllFeatures.f2,
            distanceAllFeatures.f3
    );
    return resultLine;
}
})
.sortPartition(1, Order.ASCENDING);
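// Each CSV row is (sourceSectionId, refSectionId, totalDistance) followed by the four
// per-feature distances in the order returned by Distances.distanceRelativeAllFeatures
// (see the comment in the map above).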
distancesAndSectionPairs.writeAsCsv(config.getOutputDir(), WriteMode.OVERWRITE);
// also re-merge the partitions of every document pair, keeping the minimum distance in each field
final DataSet<Tuple7<String, String, Double, Double, Double, Double, Double>> minDistancesOfRemergedDocs = distancesAndSectionPairs
.map(new MapFunction<Tuple7<String, String, Double, Double, Double, Double, Double>, Tuple7<String, String, Double, Double, Double, Double, Double>>() {
@Override
public Tuple7<String, String, Double, Double, Double, Double, Double> map(Tuple7<String, String, Double, Double, Double, Double, Double> distanceTuple) throws Exception {
    String id0 = distanceTuple.f0;
    String id1 = distanceTuple.f1;
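// strip the last path segment so all partitions of the same document share one id,
// e.g. "someDoc/page3" -> "someDoc" (assuming ids end with "/<page>")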
id0 = id0.substring(0, id0.lastIndexOf("/"));
id1 = id1.substring(0, id1.lastIndexOf("/"));
return new Tuple7<>(id0, id1, distanceTuple.f2,
        distanceTuple.f3,
        distanceTuple.f4,
        distanceTuple.f5,
        distanceTuple.f6);
}
})
.groupBy(0, 1)
.reduceGroup(new GroupReduceFunction<Tuple7<String, String, Double, Double, Double, Double, Double>, Tuple7<String, String, Double, Double, Double, Double, Double>>() {
@Override
public void reduce(Iterable<Tuple7<String, String, Double, Double, Double, Double, Double>> iterable, Collector<Tuple7<String, String, Double, Double, Double, Double, Double>> collector) throws Exception {
double f2 = Double.MAX_VALUE, f3 = Double.MAX_VALUE, f4 = Double.MAX_VALUE, f5 = Double.MAX_VALUE, f6 = Double.MAX_VALUE;
String s0 = null, s1 = null;
for (Tuple7<String, String, Double, Double, Double, Double, Double> cur : iterable) {
if (s0 == null)
s0 = cur.f0;
if (s1 == null)
s1 = cur.f1;
f2 = Math.min(f2, cur.f2);
f3 = Math.min(f3, cur.f3);
f4 = Math.min(f4, cur.f4);
f5 = Math.min(f5, cur.f5);
f6 = Math.min(f6, cur.f6);
}
collector.collect(new Tuple7<>(
s0, s1, f2, f3, f4, f5, f6
));
}
});
minDistancesOfRemergedDocs.writeAsCsv(config.getOutputDir() + "_remergedbymindist", WriteMode.OVERWRITE);
// we can now use the distances-and-section-pairs dataset to aggregate the distances at document level into distance bins
//noinspection Convert2Lambda
DataSet<Tuple5<String, String, Double, Double, Double>> binnedDistancesForPairs =
distancesAndSectionPairs
.reduceGroup(new GroupReduceFunction<
Tuple7<String, String, Double, Double, Double, Double, Double>,
Tuple5<String, String, Double, Double, Double>>() {
@Override
public void reduce(Iterable<Tuple7<String, String, Double, Double, Double, Double, Double>> iterable, Collector<Tuple5<String, String, Double, Double, Double>> collector) throws Exception {
// histogramPairOfNameAndBinWithFrequency maps a key of (name0, name1, binLower, binUpper)
// to the frequency of that bin for that pair of documents;
// histogramPairOfNameWithFrequency maps (name0, name1) to the pair's total frequency.
final HashMap<Tuple4<String, String, Double, Double>, Double> histogramPairOfNameAndBinWithFrequency = new HashMap<>();
final HashMap<Tuple2<String, String>, Double> histogramPairOfNameWithFrequency = new HashMap<>();
for (Tuple7<String, String, Double, Double, Double, Double, Double> curPairWithDistances : iterable) {
final String id0 = curPairWithDistances.f0;
final String id1 = curPairWithDistances.f1;
final String name0 = ExtractedMathPDDocument.getNameFromId(id0);
final String name1 = ExtractedMathPDDocument.getNameFromId(id1);
double distance = curPairWithDistances.f2 / 4.0; // normalize the accumulated distance by the number of features
// the key: both document names plus the bin's lower and upper boundaries
final Tuple4<String, String, Double, Double> key =
new Tuple4<>(
name0,
name1,
getBinBoundary(distance, 0.2, true),
getBinBoundary(distance, 0.2, false));
final Tuple2<String, String> keyName = new Tuple2<>(name0, name1);
// look up if something has been stored under this key
Double frequencyOfCurKey = histogramPairOfNameAndBinWithFrequency.getOrDefault(key, 0.0);
histogramPairOfNameAndBinWithFrequency.put(key, frequencyOfCurKey + 1.0);
// also update the pair's total frequency
histogramPairOfNameWithFrequency.put(keyName, histogramPairOfNameWithFrequency.getOrDefault(keyName, 0.0) + 1.0);
}
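// Emit each bin's relative frequency: bin count divided by the pair's total count.
// E.g., if a pair has 10 section pairs and 4 of them fall into the bin [0.4, 0.6),
// the emitted tuple is (name0, name1, 0.4, 0.6, 0.4).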
for (Tuple4<String, String, Double, Double> key : histogramPairOfNameAndBinWithFrequency.keySet()) {
collector.collect(new Tuple5<>(key.f0, key.f1, key.f2, key.f3, histogramPairOfNameAndBinWithFrequency.get(key) / histogramPairOfNameWithFrequency.get(new Tuple2<>(key.f0, key.f1))));
}
}
})
.sortPartition(0, Order.ASCENDING)
.sortPartition(1, Order.ASCENDING);
binnedDistancesForPairs.writeAsCsv(config.getOutputDir() + "_binned", WriteMode.OVERWRITE);
}
env.execute(String.format("MathPD(IS_MODE_PREPROCESSING=%b)", IS_MODE_PREPROCESSING));
}
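/**
 * Returns the lower or upper boundary of the bin of width binWidth that contains value,
 * e.g. value 0.43 with binWidth 0.2 yields 0.4 (lower) or 0.6 (upper).
 * The result is round-tripped through decimalFormat to trim floating-point noise.
 */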
private static double getBinBoundary(double value, double binWidth, boolean isLower) {
double flooredDivision = Math.floor(value / binWidth);
double binBoundary;
if (isLower)
binBoundary = binWidth * flooredDivision;
else
binBoundary = binWidth * (flooredDivision + 1);
return Double.valueOf(decimalFormat.format(binBoundary));
}
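// The input corpus concatenates many document snippets; "</ARXIVFILESPLIT>" terminates each
// one, so using it as the record delimiter makes Flink emit one snippet per String record.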
public static DataSource<String> readWikiDump(FlinkPdCommandConfig config, ExecutionEnvironment env) {
    return readArxivSplitFile(config.getDataset(), env);
}
public static DataSource<String> readRefs(FlinkPdCommandConfig config, ExecutionEnvironment env) {
    return readArxivSplitFile(config.getRef(), env);
}
private static DataSource<String> readArxivSplitFile(String pathname, ExecutionEnvironment env) {
    TextInputFormat inp = new TextInputFormat(new Path(pathname));
    inp.setCharsetName("UTF-8");
    inp.setDelimiter("</ARXIVFILESPLIT>");
    return env.readFile(inp, pathname);
}
public static DataSource<String> readPreprocessedFile(String pathname, ExecutionEnvironment env) {
Path filePath = new Path(pathname);
TextInputFormat inp = new TextInputFormat(filePath);
inp.setCharsetName("UTF-8");
return env.readFile(inp, pathname);
}
public String runFromText(FlinkPdCommandConfig config, String input) throws Exception {
final JsonSerializerMapper<Object> serializerMapper = new JsonSerializerMapper<>();
return serializerMapper.map(outDocFromText(config, input));
}
public WikiDocumentOutput outDocFromText(FlinkPdCommandConfig config, String input) throws Exception {
final TextAnnotatorMapper textAnnotatorMapper = new TextAnnotatorMapper(config);
textAnnotatorMapper.open(null);
final CreateCandidatesMapper candidatesMapper = new CreateCandidatesMapper(config);
final ParsedWikiDocument parsedWikiDocument = textAnnotatorMapper.parse(input);
return candidatesMapper.map(parsedWikiDocument);
}
}