package com.formulasearchengine.mathosphere.mathpd; import com.formulasearchengine.mathmltools.mml.CMMLInfo; import com.formulasearchengine.mathmltools.xmlhelper.NonWhitespaceNodeList; import com.formulasearchengine.mathmltools.xmlhelper.XMLHelper; import com.formulasearchengine.mathosphere.mathpd.distances.earthmover.EarthMoverDistanceWrapper; import com.formulasearchengine.mathosphere.mathpd.distances.earthmover.JFastEMD; import com.formulasearchengine.mathosphere.mathpd.distances.earthmover.Signature; import com.formulasearchengine.mathosphere.mathpd.pojos.ArxivDocument; import com.formulasearchengine.mathosphere.mathpd.pojos.ExtractedMathPDDocument; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.flink.api.java.tuple.Tuple4; import org.w3c.dom.Node; import org.w3c.dom.NodeList; import javax.xml.parsers.ParserConfigurationException; import javax.xml.transform.TransformerException; import javax.xml.xpath.XPathExpressionException; import java.io.IOException; import java.text.DecimalFormat; import java.util.*; /** * Created by Felix Hamborg <felixhamborg@gmail.com> on 05.12.16. */ public class Distances { private static final Log LOG = LogFactory.getLog(Distances.class); private static final DecimalFormat decimalFormat = new DecimalFormat("#.###"); /** * probably only makes sense to compute this on CI * * @param h1 * @param h2 * @return */ public static double computeEarthMoverAbsoluteDistance(Map<String, Double> h1, Map<String, Double> h2) { Signature s1 = EarthMoverDistanceWrapper.histogramToSignature(h1); Signature s2 = EarthMoverDistanceWrapper.histogramToSignature(h2); return JFastEMD.distance(s1, s2, 0.0); } public static double computeRelativeDistance(Map<String, Double> h1, Map<String, Double> h2) { int totalNumberOfElements = 0; for (Double frequency : h1.values()) { totalNumberOfElements += frequency; } for (Double frequency : h2.values()) { totalNumberOfElements += frequency; } if (totalNumberOfElements == 0) { return 0.0; } final double absoluteDistance = computeAbsoluteDistance(h1, h2); return absoluteDistance / totalNumberOfElements; } /** * compares two histograms and returns the accumulated number of differences (absolute) * * @param h1 * @param h2 * @return */ public static double computeAbsoluteDistance(Map<String, Double> h1, Map<String, Double> h2) { double distance = 0; final Set<String> keySet = new HashSet(); keySet.addAll(h1.keySet()); keySet.addAll(h2.keySet()); for (String key : keySet) { double v1 = 0.0; double v2 = 0.0; if (h1.get(key) != null) { v1 = h1.get(key); } if (h2.get(key) != null) { v2 = h2.get(key); } distance += Math.abs(v1 - v2); } return distance; } /** * Returns a map of the names and their accumulated frequency of the given content-elements (that could be identifiers, numbers, or operators) * * @param nodes * @return */ protected static HashMap<String, Double> contentElementsToHistogram(NodeList nodes) { final HashMap<String, Double> histogram = new HashMap<>(); for (int i = 0; i < nodes.getLength(); i++) { Node node = nodes.item(i); String contentElementName = node.getTextContent().trim(); // increment frequency by 1 histogram.put(contentElementName, histogram.getOrDefault(contentElementName, 0.0) + 1.0); } return histogram; } /** * Adds all elements from all histogram * * @return */ public static HashMap<String, Double> histogramsPlus(List<HashMap<String, Double>> histograms) { return histogramsPlus(histograms.toArray(new HashMap[histograms.size()])); } /** * Adds all elements from all histogram * * @return */ public static HashMap<String, Double> histogramsPlus(HashMap<String, Double>... histograms) { switch (histograms.length) { case 0: throw new IllegalArgumentException("histograms.length=" + histograms.length + "; needs to be >= 2"); // return null; case 1: return histograms[0]; } final Set<String> mergedKeys = new HashSet<>(); for (HashMap<String, Double> histogram : histograms) { mergedKeys.addAll(histogram.keySet()); } final HashMap<String, Double> mergedHistogram = new HashMap<>(); for (String key : mergedKeys) { double value = 0.0; for (HashMap<String, Double> histogram : histograms) { value += histogram.getOrDefault(key, 0.0); } mergedHistogram.put(key, value); } return mergedHistogram; } /** * Returns an absolute histogram of the whole document d with all elements that match tagname. The key in the histogram is the element's name. * * @param d * @param tagName * @return * @throws XPathExpressionException * @throws ParserConfigurationException * @throws TransformerException * @throws IOException */ public static HashMap<String, Double> getDocumentHistogram(ArxivDocument d, String tagName, NonWhitespaceNodeList allMathTagsOfDOc) throws XPathExpressionException, ParserConfigurationException, TransformerException, IOException { LOG.debug("getDocumentHistogram(" + d.title + ", " + tagName + ")"); HashMap<String, Double> mergedHistogram = new HashMap<>(); final NonWhitespaceNodeList allMathTags = (allMathTagsOfDOc != null) ? allMathTagsOfDOc : d.getMathTags(); for (int i = 0; i < allMathTags.getLength(); i++) { final Node mathTag = allMathTags.item(i); // this hack is necessary, as the converter that generates StrictCMML does not work correctly for CN, e.g., the number 3 is converted into a cn 10 as a base and a cs 3 as the actual number. if (tagName.equals("cn")) { mergedHistogram = histogramsPlus(mergedHistogram, cmmlNodeToHistrogram(mathTag, tagName)); } else { final CMMLInfo curStrictCmml = new CMMLInfo(mathTag).toStrictCmml(); LOG.trace(curStrictCmml.toString()); mergedHistogram = histogramsPlus(mergedHistogram, strictCmmlInfoToHistogram(curStrictCmml, tagName)); } } // cleanup cleanupHistogram(tagName, mergedHistogram); LOG.debug("getDocumentHistogram(" + d.title + ", " + tagName + "): " + mergedHistogram); return mergedHistogram; } /** * converts strict content math ml to a histogram for the given tagname, e.g., ci * * @param strictCmml * @param tagName * @return */ private static HashMap<String, Double> strictCmmlInfoToHistogram(CMMLInfo strictCmml, String tagName) { final NodeList elements = strictCmml.getElementsByTagName(tagName); return contentElementsToHistogram(elements); } /** * converts content math ml to a histogram for the given tagname, e.g., cn * * @param node * @param tagName * @return */ private static HashMap<String, Double> cmmlNodeToHistrogram(Node node, String tagName) throws XPathExpressionException { final NodeList elements = XMLHelper.getElementsB(node, "*//*:" + tagName); return contentElementsToHistogram(elements); } public static Tuple4<Double, Double, Double, Double> distanceAbsoluteAllFeatures(ExtractedMathPDDocument f0, ExtractedMathPDDocument f1) { final double absoluteDistanceContentNumbers = computeAbsoluteDistance(f0.getHistogramCn(), f1.getHistogramCn()); final double absoluteDistanceContentOperators = computeAbsoluteDistance(f0.getHistogramCsymbol(), f1.getHistogramCsymbol()); final double absoluteDistanceContentIdentifiers = computeAbsoluteDistance(f0.getHistogramCi(), f1.getHistogramCi()); final double absoluteDistanceBoundVariables = computeAbsoluteDistance(f0.getHistogramBvar(), f1.getHistogramBvar()); LOG.debug("the following distances should all be 0"); LOG.debug(getDocDescription(f0, f1) + "CN " + decimalFormat.format(absoluteDistanceContentNumbers)); LOG.debug(getDocDescription(f0, f1) + "CSYMBOL " + decimalFormat.format(absoluteDistanceContentOperators)); LOG.debug(getDocDescription(f0, f1) + "CI " + decimalFormat.format(absoluteDistanceContentIdentifiers)); LOG.debug(getDocDescription(f0, f1) + "BVAR " + decimalFormat.format(absoluteDistanceBoundVariables)); return new Tuple4<>(absoluteDistanceContentNumbers, absoluteDistanceContentOperators, absoluteDistanceContentIdentifiers, absoluteDistanceBoundVariables); } public static double computeCosineDistance(HashMap<String, Double> h1, HashMap<String, Double> h2) { final Set<String> mergedKeys = new HashSet<>(h1.keySet()); mergedKeys.addAll(h2.keySet()); // if both histograms are empty, they are same if (h1.size() + h2.size() == 0) { return -10.0; // tmp value for development, TODO, replace with 1.0 once finished. } // if at least one histogram is not empty but no keys are shared, the documents will be completely different if (mergedKeys.isEmpty()) { return 0.0; } // https://en.wikipedia.org/wiki/Cosine_similarity double numerator = 0.0; for (String key : mergedKeys) { numerator += (h1.getOrDefault(key, 0.0) * h2.getOrDefault(key, 0.0)); } double denominator1 = 0.0; for (String key : h1.keySet()) { double value = h1.get(key); denominator1 += (value * value); } denominator1 = Math.sqrt(denominator1); double denominator2 = 0.0; for (String key : h2.keySet()) { double value = h2.get(key); denominator2 += (value * value); } denominator2 = Math.sqrt(denominator2); return numerator / (denominator1 * denominator2); } public static Tuple4<Double, Double, Double, Double> distanceCosineAllFeatures(ExtractedMathPDDocument f0, ExtractedMathPDDocument f1) { final double cosineDistanceContentNumbers = computeCosineDistance(f0.getHistogramCn(), f1.getHistogramCn()); final double cosineDistanceContentOperators = computeCosineDistance(f0.getHistogramCsymbol(), f1.getHistogramCsymbol()); final double cosineDistanceContentIdentifiers = computeCosineDistance(f0.getHistogramCi(), f1.getHistogramCi()); final double cosineDistanceBoundVariables = computeCosineDistance(f0.getHistogramBvar(), f1.getHistogramBvar()); LOG.debug(getDocDescription(f0, f1) + "CN " + decimalFormat.format(cosineDistanceContentNumbers)); LOG.debug(getDocDescription(f0, f1) + "CSYMBOL " + decimalFormat.format(cosineDistanceContentOperators)); LOG.debug(getDocDescription(f0, f1) + "CI " + decimalFormat.format(cosineDistanceContentIdentifiers)); LOG.debug(getDocDescription(f0, f1) + "BVAR " + decimalFormat.format(cosineDistanceBoundVariables)); return new Tuple4<>(cosineDistanceContentNumbers, cosineDistanceContentOperators, cosineDistanceContentIdentifiers, cosineDistanceBoundVariables); } public static Tuple4<Double, Double, Double, Double> distanceRelativeAllFeatures(ExtractedMathPDDocument f0, ExtractedMathPDDocument f1) { final double relativeDistanceContentNumbers = computeRelativeDistance(f0.getHistogramCn(), f1.getHistogramCn()); final double relativeDistanceContentOperators = computeRelativeDistance(f0.getHistogramCsymbol(), f1.getHistogramCsymbol()); final double relativeDistanceContentIdentifiers = computeRelativeDistance(f0.getHistogramCi(), f1.getHistogramCi()); final double relativeDistanceBoundVariables = computeRelativeDistance(f0.getHistogramBvar(), f1.getHistogramBvar()); LOG.debug(getDocDescription(f0, f1) + "CN " + decimalFormat.format(relativeDistanceContentNumbers)); LOG.debug(getDocDescription(f0, f1) + "CSYMBOL " + decimalFormat.format(relativeDistanceContentOperators)); LOG.debug(getDocDescription(f0, f1) + "CI " + decimalFormat.format(relativeDistanceContentIdentifiers)); LOG.debug(getDocDescription(f0, f1) + "BVAR " + decimalFormat.format(relativeDistanceBoundVariables)); return new Tuple4<>(relativeDistanceContentNumbers, relativeDistanceContentOperators, relativeDistanceContentIdentifiers, relativeDistanceBoundVariables); } private static String getDocDescription(ExtractedMathPDDocument f0, ExtractedMathPDDocument f1) { return "{" + f0.getTitle() + "; " + f1.getTitle() + "} "; } /** * this cleanup is necessary due to errors in the xslt conversion script (contentmathmml to strict cmml) * * @param tagName * @param histogram */ private static void cleanupHistogram(String tagName, HashMap<String, Double> histogram) { switch (tagName) { case "csymbol": histogram.remove("based_integer"); for (String key : ValidCSymbols.VALID_CSYMBOLS) { histogram.remove(key); } break; case "ci": histogram.remove("integer"); break; case "cn": Set<String> toberemovedKeys = new HashSet<>(); for (String key : histogram.keySet()) { if (!isNumeric(key)) { toberemovedKeys.add(key); } } // now we can remove the keys for (String key : toberemovedKeys) { histogram.remove(key); } break; } } private static boolean isNumeric(String str) { return str.matches("-?\\d+(\\.\\d+)?"); //match a number with optional '-' and decimal. } }