package com.constellio.data.dao.services.bigVault;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
public class JaccardTermVectorSimilarity {
public static final String TF = "tf";
public static final String TF_IDF = "tf-idf";
private final String freqTag;
private final String scoreTag;
public JaccardTermVectorSimilarity() {
this.freqTag = TF;
this.scoreTag = TF_IDF;
}
public double getSimilarity(Map<String, Map<String, Double>> doc1, Map<String, Map<String, Double>> doc2) {
return intersection(doc1, doc2) / union(doc1, doc2);
}
public double intersection(Map<String, Map<String, Double>> doc1, Map<String, Map<String, Double>> doc2) {
Set<String> intersectTerms = new TreeSet<>(doc1.keySet());
intersectTerms.retainAll(doc2.keySet());
double intersectSize = 0;
for (String term : intersectTerms) {
double minFreq = Math.min(safeGetVal(doc1, term, freqTag), safeGetVal(doc2, term, freqTag));
double minScore = Math.min(safeGetVal(doc1, term, scoreTag), safeGetVal(doc2, term, scoreTag));
intersectSize += minFreq * minScore;
}
return intersectSize;
}
public double union(Map<String, Map<String, Double>> doc1, Map<String, Map<String, Double>> doc2) {
Set<String> intersectTerms = new TreeSet<>(doc1.keySet());
intersectTerms.addAll(doc2.keySet());
double intersectSize = 0;
for (String term : intersectTerms) {
double minFreq = Math.max(safeGetVal(doc1, term, freqTag), safeGetVal(doc2, term, freqTag));
double minScore = Math.max(safeGetVal(doc1, term, scoreTag), safeGetVal(doc2, term, scoreTag));
intersectSize += minFreq * minScore;
}
return intersectSize;
}
private Double safeGetVal(Map<String, Map<String, Double>> doc, String term, String tag) {
Map<String, Double> map = doc.get(term);
if (map == null)
return 0.0;
return map.get(tag);
}
}