package org.molgenis.data.semanticsearch.service.impl;
import com.google.common.collect.Sets;
import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.molgenis.data.DataService;
import org.molgenis.data.Entity;
import org.molgenis.data.MolgenisDataAccessException;
import org.molgenis.data.QueryRule;
import org.molgenis.data.QueryRule.Operator;
import org.molgenis.data.meta.model.AttributeMetadata;
import org.molgenis.data.meta.model.EntityType;
import org.molgenis.data.meta.model.EntityTypeMetadata;
import org.molgenis.data.semanticsearch.string.NGramDistanceAlgorithm;
import org.molgenis.data.semanticsearch.string.Stemmer;
import org.molgenis.data.support.QueryImpl;
import org.molgenis.ontology.core.model.OntologyTerm;
import org.molgenis.ontology.core.service.OntologyService;
import org.molgenis.ontology.ic.TermFrequencyService;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.*;
import java.util.stream.Collectors;
import static java.util.Arrays.stream;
import static java.util.Objects.requireNonNull;
import static org.molgenis.data.meta.AttributeType.COMPOUND;
import static org.molgenis.data.meta.model.EntityTypeMetadata.ENTITY_TYPE_META_DATA;
/**
 * Helper that turns search terms and ontology terms into {@link QueryRule}s and expanded query maps, and that
 * provides the string pre-processing (stop-word removal, boosting, escaping) used while doing so.
 */
public class SemanticSearchServiceHelper
{
	public static final int MAX_NUM_TAGS = 3;

	private static final char SPACE_CHAR = ' ';
	private static final String COMMA_CHAR = ",";
	private static final String CARET_CHARACTER = "^";
	private static final String ESCAPED_CARET_CHARACTER = "\\^";
	private static final String ILLEGAL_CHARS_REGEX = "[^\\p{L}'a-zA-Z0-9\\.~]+";

	private final TermFrequencyService termFrequencyService;
	private final DataService dataService;
	private final OntologyService ontologyService;

	/**
	 * @param dataService          used to look up entity type metadata, must not be null
	 * @param ontologyService      used to resolve ontology terms and their children, must not be null
	 * @param termFrequencyService used to look up the frequency value of a term, must not be null
	 */
	@Autowired
	public SemanticSearchServiceHelper(DataService dataService, OntologyService ontologyService,
			TermFrequencyService termFrequencyService)
	{
		this.dataService = requireNonNull(dataService);
		this.ontologyService = requireNonNull(ontologyService);
		this.termFrequencyService = requireNonNull(termFrequencyService);
	}

	/**
	 * Create a disMaxJunc query rule based on the given search terms as well as the information from the given
	 * ontology terms. Ontology terms whose IRI contains a comma are treated as composite tags and are added as nested
	 * should-rules; all other ontology terms contribute their label/synonym queries to the dis-max rule itself.
	 *
	 * @param searchTerms   free-text search terms, blank entries are ignored; may be null
	 * @param ontologyTerms ontology terms to expand the query with; may be null
	 * @return disMaxJunc queryRule
	 */
	public QueryRule createDisMaxQueryRuleForAttribute(Set<String> searchTerms, Collection<OntologyTerm> ontologyTerms)
	{
		List<String> queryTerms = new ArrayList<>();
		if (searchTerms != null)
		{
			for (String searchTerm : searchTerms)
			{
				if (StringUtils.isNotBlank(searchTerm))
				{
					queryTerms.add(processQueryString(searchTerm));
				}
			}
		}

		// Tags with only one ontology term feed the term list directly; composite tags are handled afterwards
		List<OntologyTerm> compositeTerms = new ArrayList<>();
		if (ontologyTerms != null)
		{
			for (OntologyTerm ontologyTerm : ontologyTerms)
			{
				if (isCompositeTerm(ontologyTerm))
				{
					compositeTerms.add(ontologyTerm);
				}
				else
				{
					queryTerms.addAll(parseOntologyTermQueries(ontologyTerm));
				}
			}
		}

		QueryRule disMaxQueryRule = createDisMaxQueryRuleForTerms(queryTerms);

		// Tags with multiple ontology terms become nested should-rules
		for (OntologyTerm compositeTerm : compositeTerms)
		{
			disMaxQueryRule.getNestedRules().add(createShouldQueryRule(compositeTerm.getIRI()));
		}
		return disMaxQueryRule;
	}

	/**
	 * Create a disMaxJunc query rule based on a list of query terms. Each non-empty term is escaped (carets are kept
	 * so that boost suffixes survive) and produces one fuzzy-match rule on the attribute label and one on the
	 * attribute description. This method does not itself lower-case the terms or remove stop words.
	 *
	 * @param queryTerms query terms, empty or null entries are skipped
	 * @return disMaxJunc queryRule
	 */
	public QueryRule createDisMaxQueryRuleForTerms(List<String> queryTerms)
	{
		List<QueryRule> rules = new ArrayList<>();
		for (String queryTerm : queryTerms)
		{
			if (StringUtils.isNotEmpty(queryTerm))
			{
				String escapedQuery = escapeCharsExcludingCaretChar(queryTerm);
				rules.add(new QueryRule(AttributeMetadata.LABEL, Operator.FUZZY_MATCH, escapedQuery));
				rules.add(new QueryRule(AttributeMetadata.DESCRIPTION, Operator.FUZZY_MATCH, escapedQuery));
			}
		}
		QueryRule finalDisMaxQuery = new QueryRule(rules);
		finalDisMaxQuery.setOperator(Operator.DIS_MAX);
		return finalDisMaxQuery;
	}

	/**
	 * Create a disMaxQueryRule with corresponding boosted value
	 *
	 * @param queryTerms query terms, see {@link #createDisMaxQueryRuleForTerms(List)}
	 * @param boostValue value to set on the rule; ignored when null or when its integer part is zero
	 * @return a disMaxQueryRule with boosted value
	 */
	public QueryRule createBoostedDisMaxQueryRuleForTerms(List<String> queryTerms, Double boostValue)
	{
		QueryRule finalDisMaxQuery = createDisMaxQueryRuleForTerms(queryTerms);
		// intValue() truncates towards zero, so values strictly between -1 and 1 are treated as "no boost".
		// NOTE(review): confirm whether fractional boosts below 1 are meant to be ignored.
		if (boostValue != null && boostValue.intValue() != 0)
		{
			finalDisMaxQuery.setValue(boostValue);
		}
		return finalDisMaxQuery;
	}

	/**
	 * Create a boolean should query for composite tags containing multiple ontology terms. The IRI is split on
	 * commas and every part is resolved through the ontology service; parts that cannot be resolved are skipped.
	 *
	 * @param multiOntologyTermIri comma separated ontology term IRIs
	 * @return return a boolean should queryRule
	 */
	public QueryRule createShouldQueryRule(String multiOntologyTermIri)
	{
		QueryRule shouldQueryRule = new QueryRule(new ArrayList<QueryRule>());
		shouldQueryRule.setOperator(Operator.SHOULD);
		for (String ontologyTermIri : multiOntologyTermIri.split(COMMA_CHAR))
		{
			OntologyTerm ontologyTerm = ontologyService.getOntologyTerm(ontologyTermIri);
			if (ontologyTerm == null)
			{
				// Unknown IRI: nothing to query for, same tolerance as collectOntologyTermQueryMap
				continue;
			}
			List<String> queryTerms = parseOntologyTermQueries(ontologyTerm);
			Double termFrequency = getBestInverseDocumentFrequency(queryTerms);
			shouldQueryRule.getNestedRules().add(createBoostedDisMaxQueryRuleForTerms(queryTerms, termFrequency));
		}
		return shouldQueryRule;
	}

	/**
	 * Create a list of string queries based on the information collected from current ontologyterm including label,
	 * synonyms and child ontologyterms. Queries derived from children are boosted by 0.5 raised to the power of the
	 * ontology term distance between the given term and the child.
	 *
	 * @param ontologyTerm the ontology term to expand
	 * @return a mutable list of query strings
	 */
	public List<String> parseOntologyTermQueries(OntologyTerm ontologyTerm)
	{
		// Explicit ArrayList: the list is appended to below and Collectors.toList() does not promise mutability
		List<String> queryTerms = new ArrayList<>();
		for (String term : getOtLabelAndSynonyms(ontologyTerm))
		{
			queryTerms.add(processQueryString(term));
		}

		for (OntologyTerm childOt : ontologyService.getChildren(ontologyTerm))
		{
			double boostedNumber = Math.pow(0.5, ontologyService.getOntologyTermDistance(ontologyTerm, childOt));
			for (String synonym : getOtLabelAndSynonyms(childOt))
			{
				queryTerms.add(parseBoostQueryString(synonym, boostedNumber));
			}
		}
		return queryTerms;
	}

	/**
	 * A helper function to collect synonyms as well as label of ontologyterm
	 *
	 * @param ontologyTerm the ontology term to read from
	 * @return an insertion-ordered set holding the synonyms followed by the label
	 */
	public Set<String> getOtLabelAndSynonyms(OntologyTerm ontologyTerm)
	{
		Set<String> allTerms = Sets.newLinkedHashSet(ontologyTerm.getSynonyms());
		allTerms.add(ontologyTerm.getLabel());
		return allTerms;
	}

	/**
	 * Build a map from the cleaned/stemmed form of a term to the term it originated from. Query terms map to
	 * themselves; terms derived from ontology terms map to the ontology term label. Composite ontology terms (IRI
	 * containing a comma) are resolved part by part through the ontology service.
	 *
	 * @param queryTerms    free-text query terms, blank entries are ignored; may be null
	 * @param ontologyTerms ontology terms to expand; may be null
	 * @return insertion-ordered map of stemmed term to originating term/label
	 */
	public Map<String, String> collectExpandedQueryMap(Set<String> queryTerms, Collection<OntologyTerm> ontologyTerms)
	{
		Map<String, String> expandedQueryMap = new LinkedHashMap<>();
		if (queryTerms != null)
		{
			for (String queryTerm : queryTerms)
			{
				if (StringUtils.isNotBlank(queryTerm))
				{
					expandedQueryMap.put(Stemmer.cleanStemPhrase(queryTerm), queryTerm);
				}
			}
		}

		if (ontologyTerms != null)
		{
			for (OntologyTerm ontologyTerm : ontologyTerms)
			{
				if (isCompositeTerm(ontologyTerm))
				{
					for (String ontologyTermIri : ontologyTerm.getIRI().split(COMMA_CHAR))
					{
						collectOntologyTermQueryMap(expandedQueryMap,
								ontologyService.getOntologyTerm(ontologyTermIri));
					}
				}
				else
				{
					collectOntologyTermQueryMap(expandedQueryMap, ontologyTerm);
				}
			}
		}
		return expandedQueryMap;
	}

	/**
	 * Add the stemmed label and synonyms of the given ontology term, and of its children, to the map. Every entry,
	 * including those derived from children, is mapped to the label of the given (parent) ontology term.
	 *
	 * @param expanedQueryMap map to add the entries to
	 * @param ontologyTerm    ontology term to collect from; nothing happens when null
	 */
	public void collectOntologyTermQueryMap(Map<String, String> expanedQueryMap, OntologyTerm ontologyTerm)
	{
		if (ontologyTerm == null)
		{
			return;
		}

		String label = ontologyTerm.getLabel();
		for (String term : getOtLabelAndSynonyms(ontologyTerm))
		{
			expanedQueryMap.put(Stemmer.cleanStemPhrase(term), label);
		}
		for (OntologyTerm childOntologyTerm : ontologyService.getChildren(ontologyTerm))
		{
			for (String term : getOtLabelAndSynonyms(childOntologyTerm))
			{
				expanedQueryMap.put(Stemmer.cleanStemPhrase(term), label);
			}
		}
	}

	/**
	 * A helper function that gets identifiers of all the attributes from one EntityType. Compound attributes are not
	 * included themselves, but the attributes nested below them are.
	 *
	 * @param sourceEntityType the entity type whose attribute identifiers are collected
	 * @return identifiers of all non-compound attributes
	 * @throws MolgenisDataAccessException when no entity type entity with the given name exists
	 */
	public List<String> getAttributeIdentifiers(EntityType sourceEntityType)
	{
		Entity entityTypeEntity = dataService.findOne(ENTITY_TYPE_META_DATA,
				new QueryImpl<Entity>().eq(EntityTypeMetadata.FULL_NAME, sourceEntityType.getName()));

		if (entityTypeEntity == null)
		{
			throw new MolgenisDataAccessException(
					"Could not find EntityTypeEntity by the name of " + sourceEntityType.getName());
		}

		List<String> attributeIdentifiers = new ArrayList<>();
		recursivelyCollectAttributeIdentifiers(entityTypeEntity.getEntities(EntityTypeMetadata.ATTRIBUTES),
				attributeIdentifiers);
		return attributeIdentifiers;
	}

	/**
	 * Depth-first walk over attribute entities that records the identifier of every attribute whose type is not
	 * COMPOUND and then descends into the attribute's children.
	 *
	 * @param attributeEntities    attribute entities to visit; nothing happens when null
	 * @param attributeIdentifiers output list the identifiers are appended to
	 */
	private void recursivelyCollectAttributeIdentifiers(Iterable<Entity> attributeEntities,
			List<String> attributeIdentifiers)
	{
		if (attributeEntities == null)
		{
			return;
		}

		String compoundType = COMPOUND.toString();
		for (Entity attributeEntity : attributeEntities)
		{
			// Constant on the left so that an attribute without a type does not trigger a NullPointerException
			if (!compoundType.equals(attributeEntity.getString(AttributeMetadata.TYPE)))
			{
				attributeIdentifiers.add(attributeEntity.getString(AttributeMetadata.ID));
			}
			recursivelyCollectAttributeIdentifiers(attributeEntity.getEntities(AttributeMetadata.CHILDREN),
					attributeIdentifiers);
		}
	}

	/**
	 * Find ontology terms matching the words of the description, after stop-word removal.
	 *
	 * @param description free text to find tags for
	 * @param ontologyIds identifiers of the ontologies to search in
	 * @return ontology terms returned by the ontology service, limited by {@link #MAX_NUM_TAGS}
	 */
	public List<OntologyTerm> findTags(String description, List<String> ontologyIds)
	{
		Set<String> searchTerms = removeStopWords(description);
		return ontologyService.findOntologyTerms(ontologyIds, searchTerms, MAX_NUM_TAGS);
	}

	/**
	 * Remove stop words from the query string and join the remaining words with a single space.
	 *
	 * @param queryString the raw query string
	 * @return space separated words without stop words
	 */
	public String processQueryString(String queryString)
	{
		return StringUtils.join(removeStopWords(queryString), SPACE_CHAR);
	}

	/**
	 * Remove stop words from the query string, append {@code ^boost} to every remaining word and join the results
	 * with a single space.
	 *
	 * @param queryString the raw query string
	 * @param boost       boost value appended to each word
	 * @return space separated boosted words
	 */
	public String parseBoostQueryString(String queryString, double boost)
	{
		Set<String> boostedWords = removeStopWords(queryString).stream().map(word -> word + CARET_CHARACTER + boost)
				.collect(Collectors.toSet());
		return StringUtils.join(boostedWords, SPACE_CHAR);
	}

	/**
	 * Escape the string with {@link QueryParser#escape(String)} and then undo the escaping of caret characters, so
	 * that boost suffixes keep working.
	 *
	 * @param string the string to escape
	 * @return the escaped string with unescaped carets
	 */
	public String escapeCharsExcludingCaretChar(String string)
	{
		return QueryParser.escape(string).replace(ESCAPED_CARET_CHARACTER, CARET_CHARACTER);
	}

	/**
	 * Split the description on runs of illegal characters, lower-case the parts and drop empty parts and stop words.
	 *
	 * @param description free text; may be null
	 * @return the set of remaining words, empty when the description is null
	 */
	public Set<String> removeStopWords(String description)
	{
		if (description == null)
		{
			return new HashSet<>();
		}
		// Locale.ROOT keeps the lower-casing independent of the JVM default locale
		return stream(description.split(ILLEGAL_CHARS_REGEX)).map(word -> word.toLowerCase(Locale.ROOT))
				.filter(word -> !NGramDistanceAlgorithm.STOPWORDSLIST.contains(word) && StringUtils.isNotEmpty(word))
				.collect(Collectors.toSet());
	}

	/**
	 * Look up the term frequency value of the shortest term in the list; on equal length the first one wins.
	 *
	 * @param terms candidate terms
	 * @return the value returned by the term frequency service, or null when the list is empty
	 */
	private Double getBestInverseDocumentFrequency(List<String> terms)
	{
		Optional<String> shortestTerm = terms.stream().min(Comparator.comparingInt(String::length));
		return shortestTerm.isPresent() ? termFrequencyService.getTermFrequency(shortestTerm.get()) : null;
	}

	/**
	 * @return true when the IRI of the ontology term contains a comma, i.e. the term is a composite tag
	 */
	private static boolean isCompositeTerm(OntologyTerm ontologyTerm)
	{
		return ontologyTerm.getIRI().contains(COMMA_CHAR);
	}
}