/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.tophits;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.util.Comparator.comparing;
/**
 * Randomized tests for {@link InternalTopHits}, driven by {@link InternalAggregationTestCase}. The random instances carry top docs
 * sorted either by score or by a random set of sort fields, and {@link #assertReduced(InternalTopHits, List)} merges several such
 * instances independently to check the reduced result.
 */
public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTopHits> {
    /**
     * Should the test instances look like they are sorted by some fields (true) or sorted by score (false). Set here because these need
     * to be the same across the entirety of {@link #testReduceRandom()}.
     */
    private final boolean testInstancesLookSortedByField = randomBoolean();
    /**
     * Fields shared by all instances created by {@link #createTestInstance(String, List, Map)}. Empty when the instances are sorted
     * by score.
     */
    private final SortField[] testInstancesSortFields = testInstancesLookSortedByField ? randomSortFields() : new SortField[0];

    @Override
    protected InternalTopHits createTestInstance(String name, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
        int from = 0;
        int requestedSize = between(1, 40);
        int actualSize = between(0, requestedSize);

        /*
         * Float.MIN_VALUE is the smallest *positive* float, not the most negative one. It is a usable starting point only because
         * randomFloat() never returns a negative score. An instance without hits keeps it as its max score.
         */
        float maxScore = Float.MIN_VALUE;
        ScoreDoc[] scoreDocs = new ScoreDoc[actualSize];
        Set<Integer> usedDocIds = new HashSet<>();
        for (int i = 0; i < actualSize; i++) {
            float score = randomFloat();
            maxScore = max(maxScore, score);
            int docId = randomValueOtherThanMany(usedDocIds::contains, () -> between(0, IndexWriter.MAX_DOCS));
            usedDocIds.add(docId);
            if (testInstancesLookSortedByField) {
                Object[] fields = new Object[testInstancesSortFields.length];
                for (int f = 0; f < testInstancesSortFields.length; f++) {
                    fields[f] = randomOfType(testInstancesSortFields[f].getType());
                }
                scoreDocs[i] = new FieldDoc(docId, score, fields);
            } else {
                scoreDocs[i] = new ScoreDoc(docId, score);
            }
        }

        /*
         * Sort the top docs first, then build the hits from them. That way hits[i] describes the same document (doc id and score)
         * as scoreDocs[i]. Top docs and hits are paired by position in assertReduced (and presumably in the reduce itself). Building
         * the hits before sorting left them describing different documents than their paired top docs.
         */
        Arrays.sort(scoreDocs, scoreDocComparator());
        SearchHit[] hits = new SearchHit[actualSize];
        for (int i = 0; i < actualSize; i++) {
            Map<String, SearchHitField> searchHitFields = new HashMap<>();
            hits[i] = new SearchHit(scoreDocs[i].doc, Integer.toString(i), new Text("test"), searchHitFields);
            hits[i].score(scoreDocs[i].score);
        }

        int totalHits = between(actualSize, 500000);
        SearchHits searchHits = new SearchHits(hits, totalHits, maxScore);
        TopDocs topDocs;
        if (testInstancesLookSortedByField) {
            topDocs = new TopFieldDocs(totalHits, scoreDocs, testInstancesSortFields, maxScore);
        } else {
            topDocs = new TopDocs(totalHits, scoreDocs, maxScore);
        }
        return new InternalTopHits(name, from, requestedSize, topDocs, searchHits, pipelineAggregators, metaData);
    }

    /**
     * Returns a random sort value for a field of the given type. It is used to fill the {@code fields} of the generated
     * {@link FieldDoc}s.
     *
     * @throws UnsupportedOperationException for {@code CUSTOM} and {@code REWRITEABLE}, which {@link #randomSortFields()} never
     *         produces, and for any type not listed here
     */
    private Object randomOfType(SortField.Type type) {
        switch (type) {
        case CUSTOM:
        case REWRITEABLE:
            throw new UnsupportedOperationException();
        case DOC:
            return between(0, IndexWriter.MAX_DOCS);
        case DOUBLE:
            return randomDouble();
        case FLOAT:
        case SCORE:
            return randomFloat();
        case INT:
            return randomInt();
        case LONG:
            return randomLong();
        case STRING:
        case STRING_VAL:
            return new BytesRef(randomAlphaOfLength(5));
        default:
            throw new UnsupportedOperationException("Unknown SortField.Type: " + type);
        }
    }

    /**
     * Checks the reduced result against an independent merge of the inputs. The expected result has:
     * <ul>
     * <li>the hits of all inputs, ordered by {@link #scoreDocComparator()} with ties broken by input index</li>
     * <li>those hits truncated to the requested size</li>
     * <li>the summed total hits</li>
     * <li>the max of the max scores of the inputs that returned hits, or {@link Float#NaN} if none did</li>
     * </ul>
     */
    @Override
    protected void assertReduced(InternalTopHits reduced, List<InternalTopHits> inputs) {
        SearchHits actualHits = reduced.getHits();
        List<Tuple<ScoreDoc, SearchHit>> allHits = new ArrayList<>();
        float maxScore = Float.MIN_VALUE;
        long totalHits = 0;
        for (int input = 0; input < inputs.size(); input++) {
            SearchHits internalHits = inputs.get(input).getHits();
            SearchHit[] inputHits = internalHits.internalHits();
            ScoreDoc[] inputDocs = inputs.get(input).getTopDocs().scoreDocs;
            totalHits += internalHits.getTotalHits();
            if (inputHits.length > 0) {
                // Only inputs that returned hits contribute to the max score. This is believed to mirror Lucene's
                // TopDocs#merge; verify against the Lucene version in use.
                maxScore = max(maxScore, internalHits.getMaxScore());
            }
            for (int i = 0; i < inputHits.length; i++) {
                ScoreDoc doc = inputDocs[i];
                // Re-create the doc with the input's index as its shard index, which the comparator uses to break ties.
                if (testInstancesLookSortedByField) {
                    doc = new FieldDoc(doc.doc, doc.score, ((FieldDoc) doc).fields, input);
                } else {
                    doc = new ScoreDoc(doc.doc, doc.score, input);
                }
                allHits.add(new Tuple<>(doc, inputHits[i]));
            }
        }
        allHits.sort(comparing(Tuple::v1, scoreDocComparator()));
        SearchHit[] expectedHitsHits = new SearchHit[min(inputs.get(0).getSize(), allHits.size())];
        for (int i = 0; i < expectedHitsHits.length; i++) {
            expectedHitsHits[i] = allHits.get(i).v2();
        }
        /*
         * Lucene's TopDocs initializes the maxScore to Float.NaN if there is no maxScore, i.e. when no input returned any hit.
         * This decision is based on the hit count, not on whether maxScore still equals its starting value. Inputs whose scores
         * are all 0.0 leave maxScore at Float.MIN_VALUE even though they did return hits.
         */
        float expectedMaxScore = allHits.isEmpty() ? Float.NaN : maxScore;
        SearchHits expectedHits = new SearchHits(expectedHitsHits, totalHits, expectedMaxScore);
        assertEqualsWithErrorMessageFromXContent(expectedHits, actualHits);
    }

    @Override
    protected Reader<InternalTopHits> instanceReader() {
        return InternalTopHits::new;
    }

    /**
     * Returns between one and five sort fields with distinct random names and random types. The types exclude {@code CUSTOM} and
     * {@code REWRITEABLE}, for which {@link #randomOfType(SortField.Type)} cannot produce values.
     */
    private SortField[] randomSortFields() {
        SortField[] sortFields = new SortField[between(1, 5)];
        Set<String> usedSortFields = new HashSet<>();
        for (int i = 0; i < sortFields.length; i++) {
            String sortField = randomValueOtherThanMany(usedSortFields::contains, () -> randomAlphaOfLength(5));
            usedSortFields.add(sortField);
            SortField.Type type = randomValueOtherThanMany(t -> t == SortField.Type.CUSTOM || t == SortField.Type.REWRITEABLE,
                    () -> randomFrom(SortField.Type.values()));
            sortFields[i] = new SortField(sortField, type);
        }
        return sortFields;
    }

    /**
     * The order the test expects hits in: {@link #innerScoreDocComparator()}, with ties broken by ascending shard index.
     */
    private Comparator<ScoreDoc> scoreDocComparator() {
        return innerScoreDocComparator().thenComparingInt(s -> s.shardIndex);
    }

    /**
     * Orders docs by their sort values when the instances look sorted by field, otherwise by descending score.
     */
    private Comparator<ScoreDoc> innerScoreDocComparator() {
        if (testInstancesLookSortedByField == false) {
            // Highest score first.
            return (lhs, rhs) -> Float.compare(rhs.score, lhs.score);
        }
        // Values passed to getComparator shouldn't matter: only compareValues is used below.
        @SuppressWarnings("rawtypes")
        FieldComparator[] comparators = new FieldComparator[testInstancesSortFields.length];
        for (int i = 0; i < testInstancesSortFields.length; i++) {
            comparators[i] = testInstancesSortFields[i].getComparator(0, 0);
        }
        return (lhs, rhs) -> {
            FieldDoc l = (FieldDoc) lhs;
            FieldDoc r = (FieldDoc) rhs;
            for (int i = 0; i < l.fields.length; i++) {
                @SuppressWarnings("unchecked")
                int c = comparators[i].compareValues(l.fields[i], r.fields[i]);
                if (c != 0) {
                    return c;
                }
            }
            return 0;
        };
    }
}