package com.constellio.data.dao.services.bigVault;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import org.apache.solr.client.solrj.SolrClient;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.client.solrj.util.ClientUtils;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.params.TermVectorParams;
public class JaccardDocumentSorter {
public static final String SIMILARITY_SCORE_FIELD = "sim_score";
private final String idField;
private final String contentField;
private final Map<String, Map<String, Double>> sourceDocTermVector;
private final Map<String, Map<String, Map<String, Double>>> doc2FieldTermVectors = new TreeMap<String, Map<String, Map<String, Double>>>();
private final SolrClient solrClient;
private final JaccardTermVectorSimilarity similarity = new JaccardTermVectorSimilarity();
public JaccardDocumentSorter(SolrClient solrClient, SolrDocument source, String contentField, String idField)
throws SolrServerException, IOException {
this.idField = idField;
this.contentField = contentField;
this.solrClient = solrClient;
sourceDocTermVector = getTermVectors(source);
if (sourceDocTermVector == null)
throw new RuntimeException();
}
private Map<String, Map<String, Double>> getTermVectors(SolrDocument doc)
throws SolrServerException, IOException {
String id = doc.getFieldValue(idField).toString();
Map<String, Map<String, Double>> result = doc2FieldTermVectors.get(id);
if (result == null) {
SolrQuery solrQuery = new SolrQuery(String.format("%s:\"%s\"", idField, ClientUtils.escapeQueryChars(id)));
solrQuery.setRequestHandler("/tvrh");
String fields = String.format("%s", contentField);
// solrQuery.setParam(CommonParams.FL, fields);
solrQuery.setParam(TermVectorParams.TF, "true");
solrQuery.setParam(TermVectorParams.DF, "true");
solrQuery.setParam(TermVectorParams.TF_IDF, "true");
solrQuery.setParam(TermVectorParams.FIELDS, fields);
solrQuery.setRows(1);
QueryResponse response = solrClient.query(solrQuery);
TermVectoreResponse termVectoreResponse = new TermVectoreResponse(response);
Map<String, Map<String, Map<String, Map<String, Double>>>> doc2FieldTermVectors = termVectoreResponse
.getDoc2FieldTermVectors();
if (!doc2FieldTermVectors.containsKey(id))
throw new RuntimeException(
"The " + contentField + " does not support termVectors, please update the solr schema file.");
result = new TreeMap<>();
for (Entry<String, Map<String, Map<String, Double>>> aFieldTermVector : doc2FieldTermVectors.get(id).entrySet())
result.putAll(aFieldTermVector.getValue());
this.doc2FieldTermVectors.put(id, result);
}
return result;
}
public List<SolrDocument> sort(List<SolrDocument> results)
throws SolrServerException, IOException {
List<SolrDocument> sortedResults = new ArrayList<>(results.size());
for (SolrDocument solrDocument : results) {
solrDocument.setField(SIMILARITY_SCORE_FIELD,
new Double(similarity.getSimilarity(sourceDocTermVector, getTermVectors(solrDocument))));
sortedResults.add(solrDocument);
}
Collections.sort(sortedResults, new Comparator<SolrDocument>() {
@Override
public int compare(SolrDocument o1, SolrDocument o2) {
Double score1 = getScore(o1);
Double score2 = getScore(o2);
Double diff = score1 - score2;
if (diff < 0)
return -1;
if (diff > 0)
return +1;
return 0;
}
private Double getScore(SolrDocument o) {
Double score = (Double) o.getFieldValue(SIMILARITY_SCORE_FIELD);
if (score == null)
score = 0D;
return score;
}
});
Collections.reverse(sortedResults);
return sortedResults;
}
}