/**
* Copyright (c) 2016 Lemur Consulting Ltd.
* <p/>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p/>
* http://www.apache.org/licenses/LICENSE-2.0
* <p/>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.flax.biosolr.ontology.search.elasticsearch;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.cardinality.Cardinality;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.flax.biosolr.ontology.api.Document;
import uk.co.flax.biosolr.ontology.config.ElasticSearchConfiguration;
import uk.co.flax.biosolr.ontology.search.DocumentSearch;
import uk.co.flax.biosolr.ontology.search.ResultsList;
import uk.co.flax.biosolr.ontology.search.SearchEngineException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * ElasticSearch implementation of {@link DocumentSearch}.
 *
 * <p>Matching documents are grouped by study ID so that each study appears at
 * most once in a page of results. The study groups are ordered by the best score
 * of any document in the group. Annotation sub-fields are fetched with a second
 * ids query, because they are not available in {@code _source}.
 *
 * Created by mlp on 17/02/16.
 *
 * @author mlp
 */
public class ElasticDocumentSearch extends ElasticSearchEngine implements DocumentSearch {

    private static final Logger LOGGER = LoggerFactory.getLogger(ElasticDocumentSearch.class);

    /** Field used to group hits, so each study is returned once. */
    private static final String GROUP_FIELD = "study_id";
    /** Name of the cardinality aggregation that counts distinct studies. */
    private static final String COUNT_AGGREGATION = "numFound";
    /** Name of the terms aggregation AND of its top_hits sub-aggregation. */
    private static final String HITS_AGGREGATION = "study";
    /** Name of the max-score sub-aggregation used to order the study buckets. */
    private static final String SCORE_AGGREGATION = "topScore";

    /** Fields searched by every query, in addition to any caller-supplied fields. */
    private static final String[] DEFAULT_FIELDS = new String[]{
            "title", "first_author", "publication", "efo_uri.label"
    };

    // NOTE(review): not referenced anywhere in this class - confirm whether it is still needed.
    private static final List<String> ANNOTATED_FIELDS = new ArrayList<>();

    static {
        ANNOTATED_FIELDS.add("label");
        ANNOTATED_FIELDS.add("child_labels");
        ANNOTATED_FIELDS.add("parent_labels");
    }

    /**
     * Shared JSON mapper. ObjectMapper is thread-safe once configured, so one
     * instance serves every search instead of being rebuilt per request.
     */
    private static final ObjectMapper MAPPER = buildObjectMapper();

    public ElasticDocumentSearch(Client client, ElasticSearchConfiguration config) {
        super(client, config);
    }

    /**
     * Search for documents matching a free-text term, returning at most one
     * document per study.
     *
     * @param term the text to search for.
     * @param start 0-based offset of the first study group to return; must not be negative.
     * @param rows the maximum number of study groups to return; must be positive.
     * @param additionalFields extra fields to search, on top of the defaults. Solr-style
     * names such as {@code efo_uri_foo_t} are mapped to the matching annotation
     * sub-field. May be null or empty.
     * @param filters NOTE(review): currently ignored by this implementation.
     * @return the requested page of documents. The total is the cardinality count of
     * distinct study IDs, which Elasticsearch computes approximately.
     * @throws SearchEngineException if a matching document's source cannot be deserialized.
     * @throws IllegalArgumentException if {@code start} is negative or {@code rows} is not positive.
     */
    @Override
    public ResultsList<Document> searchDocuments(String term, int start, int rows,
            List<String> additionalFields, List<String> filters) throws SearchEngineException {
        // Validate up front: previously rows == 0 caused a division by zero, and a
        // negative start caused an index error, only after the query had already run.
        if (start < 0) {
            throw new IllegalArgumentException("start must not be negative: " + start);
        }
        if (rows <= 0) {
            throw new IllegalArgumentException("rows must be positive: " + rows);
        }

        SearchRequestBuilder srb = buildSearchRequest(term, start, rows, additionalFields);
        LOGGER.debug("ES Query: {}", srb.toString());
        SearchResponse response = srb.execute().actionGet();

        // Handle the response
        Cardinality count = response.getAggregations().get(COUNT_AGGREGATION);
        long total = count.getValue();

        List<Document> docs;
        if (total == 0) {
            docs = new ArrayList<>();
        } else {
            docs = extractPage(response, start, rows);
        }

        return new ResultsList<>(docs, start, start / rows, total);
    }

    /**
     * Not yet implemented for ElasticSearch.
     *
     * @return always {@code null} at present. NOTE(review): callers must handle this;
     * consider throwing UnsupportedOperationException instead.
     */
    @Override
    public ResultsList<Document> searchByEfoUri(int start, int rows, String term, String... uris) throws SearchEngineException {
        return null;
    }

    /**
     * Build the grouped search request used by {@link #searchDocuments}.
     * The request contains the query, a terms aggregation holding one bucket per
     * study, and a cardinality aggregation counting the distinct studies.
     */
    private SearchRequestBuilder buildSearchRequest(String term, int start, int rows,
            List<String> additionalFields) {
        // Build the query
        MultiMatchQueryBuilder qb = QueryBuilders.multiMatchQuery(term, DEFAULT_FIELDS)
                .minimumShouldMatch("2<25%");
        parseAdditionalFields(additionalFields).forEach(qb::field);

        // Only the first hit in each study bucket is needed to build the Document.
        TopHitsBuilder topHitsBuilder = AggregationBuilders.topHits(HITS_AGGREGATION)
                .setFrom(0)
                .setSize(1);

        /* Build the terms aggregation, since we need a result set grouped by study ID.
         * The SCORE_AGGREGATION (max _score) sub-agg allows us to sort by the top score
         * of the results; the topHits sub-agg actually pulls back the record data,
         * returning just the first hit in the aggregation.
         * Note that we have to get _all_ buckets up to and including the last required,
         * because terms aggregations cannot be paged. */
        AggregationBuilder termsAgg = AggregationBuilders.terms(HITS_AGGREGATION)
                .field(GROUP_FIELD)
                .order(Terms.Order.aggregation(SCORE_AGGREGATION, false))
                .size(start + rows)
                .subAggregation(
                        AggregationBuilders.max(SCORE_AGGREGATION)
                                .script(new Script("_score", ScriptService.ScriptType.INLINE, "expression", null)))
                .subAggregation(topHitsBuilder);

        // Build the actual search request, including another aggregation to get
        // the number of unique study IDs returned.
        return getClient().prepareSearch(getIndexName())
                .setTypes(getDocumentType())
                .setQuery(qb)
                .setSize(0)
                .addAggregation(termsAgg)
                .addAggregation(AggregationBuilders.cardinality(COUNT_AGGREGATION).field(GROUP_FIELD));
    }

    /**
     * Turn the requested slice of study buckets into Documents, then populate
     * their annotation fields.
     *
     * @throws SearchEngineException if a document's source cannot be deserialized.
     */
    private List<Document> extractPage(SearchResponse response, int start, int rows)
            throws SearchEngineException {
        // Use the Terms interface rather than StringTerms, so that a non-string
        // study_id mapping does not cause a ClassCastException.
        Terms terms = response.getAggregations().get(HITS_AGGREGATION);
        List<? extends Terms.Bucket> buckets = terms.getBuckets();

        // Clamp the page to the buckets actually returned. The cardinality total is
        // approximate and must not be used as a list bound. Also, start may lie
        // beyond the last bucket.
        int bucketCount = buckets.size();
        int from = Math.min(start, bucketCount);
        int to = (int) Math.min((long) start + rows, bucketCount);

        // Build a map keyed by document ID - annotation data must be looked up
        // separately, because it is not in _source and the fields() method is not
        // visible for a TopHitsBuilder.
        Map<String, Document> documentMap = new LinkedHashMap<>();
        for (Terms.Bucket bucket : buckets.subList(from, to)) {
            TopHits hits = bucket.getAggregations().get(HITS_AGGREGATION);
            SearchHit hit = hits.getHits().getAt(0);
            documentMap.put(hit.getId(), extractDocument(MAPPER, hit));
        }

        // Populate annotation data for the documents
        lookupAnnotationFields(documentMap);
        return new ArrayList<>(documentMap.values());
    }

    /**
     * Map caller-supplied field names onto annotation sub-fields.
     * Solr dynamic names such as {@code efo_uri_foo_t} are reduced to {@code foo}.
     * Every name is then prefixed with the annotation field name and a dot.
     *
     * @return the mapped field names; an empty list (never null) if there are none.
     */
    private List<String> parseAdditionalFields(List<String> additional) {
        if (additional == null || additional.isEmpty()) {
            return new ArrayList<>();
        }
        String prefix = getAnnotationField() + ".";
        // Need to handle hard-coded Solr field names, and add the annotation field name.
        return additional.stream()
                .map(add -> add.replaceAll("^efo_uri_(.*)_t$", "$1"))
                .map(add -> prefix + add)
                .collect(Collectors.toList());
    }

    /** Create a mapper that tolerates _source properties Document does not declare. */
    private static ObjectMapper buildObjectMapper() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        return mapper;
    }

    /**
     * Deserialize a hit's {@code _source} into a Document.
     *
     * @throws SearchEngineException wrapping the IOException if the source cannot be read.
     */
    private Document extractDocument(ObjectMapper mapper, SearchHit hit) throws SearchEngineException {
        try {
            return mapper.readValue(hit.getSourceAsString(), Document.class);
        } catch (IOException e) {
            LOGGER.error("Error reading document from source: {}", e.getMessage());
            throw new SearchEngineException(e);
        }
    }

    /**
     * Fetch all stored fields for the given document IDs with an ids query.
     * Annotation values are copied onto the matching Documents in the map.
     */
    private void lookupAnnotationFields(Map<String, Document> idMap) {
        if (idMap.isEmpty()) {
            // Nothing to look up - don't send an ids query with no ids.
            return;
        }

        QueryBuilder qb = QueryBuilders.idsQuery(getDocumentType()).addIds(idMap.keySet());
        SearchRequestBuilder srb = getClient().prepareSearch(getIndexName())
                .addFields("*")
                .setQuery(qb)
                .setSize(idMap.size());
        LOGGER.debug("Annotation field lookup query: {}", srb.toString());

        SearchResponse response = srb.execute().actionGet();
        for (SearchHit hit : response.getHits().getHits()) {
            populateAnnotationFields(hit, idMap.get(hit.getId()));
        }
    }

    /**
     * Copy annotation sub-field values from a hit onto a Document.
     * Label, child and parent label fields are set directly. Fields ending in
     * {@code _rel_uris} or {@code _rel_labels} are added to the related IRI or
     * related label maps, keyed by the field name without the annotation prefix.
     */
    private void populateAnnotationFields(SearchHit hit, Document doc) {
        if (doc == null) {
            return;
        }

        // Derive the prefix from the configured annotation field instead of
        // hard-coding "efo_uri."; requiring the dot also avoids a
        // StringIndexOutOfBoundsException for a field named exactly like the prefix.
        String prefix = getAnnotationField() + ".";
        for (Map.Entry<String, SearchHitField> fieldEntry : hit.fields().entrySet()) {
            String fieldName = fieldEntry.getKey();
            if (!fieldName.startsWith(prefix)) {
                continue;
            }

            String shortName = fieldName.substring(prefix.length());
            List<String> values = getStringValues(fieldEntry.getValue().getValues());
            switch (shortName) {
                case "label":
                    doc.setEfoLabels(values);
                    break;
                case "child_labels":
                    doc.setChildLabels(values);
                    break;
                case "parent_labels":
                    doc.setParentLabels(values);
                    break;
                default:
                    // Skip empty value lists for both relation maps (previously only
                    // the label map was guarded against null).
                    if (values != null) {
                        if (shortName.endsWith("_rel_uris")) {
                            doc.getRelatedIris().put(shortName, values);
                        } else if (shortName.endsWith("_rel_labels")) {
                            doc.getRelatedLabels().put(shortName, values);
                        }
                    }
                    break;
            }
        }
    }

    /**
     * Convert raw field values to strings.
     *
     * @return the values as strings, or null if there are no values.
     */
    private List<String> getStringValues(List<Object> fieldValues) {
        if (fieldValues == null || fieldValues.isEmpty()) {
            return null;
        }
        return fieldValues.stream().map(Object::toString).collect(Collectors.toList());
    }
}