/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.functionscore;
import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.RandomScoreFunctionBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.ScoreAccessor;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESIntegTestCase;
import org.hamcrest.CoreMatchers;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.fieldValueFactorFunction;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.randomFunction;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction;
import static org.elasticsearch.script.MockScriptPlugin.NAME;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.nullValue;
public class RandomScoreFunctionIT extends ESIntegTestCase {
/**
 * Returns the plugin classes this test contributes: only {@link CustomScriptPlugin},
 * which supplies the mock scoring scripts used by {@link #testScoreAccessWithinScript()}.
 */
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
    final Collection<Class<? extends Plugin>> plugins = Arrays.asList(CustomScriptPlugin.class);
    return plugins;
}
/**
 * Mock script plugin that maps the script sources used by
 * {@link RandomScoreFunctionIT#testScoreAccessWithinScript()} to Java implementations.
 * Every registered script evaluates {@code log(index + factor * score)}; the scripts only
 * differ in which {@link ScoreAccessor} conversion is used to read the score.
 */
public static class CustomScriptPlugin extends MockScriptPlugin {

    /**
     * Builds the lookup table from script source to its Java implementation.
     *
     * @return a mutable map keyed by the exact script source string
     */
    @Override
    protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
        final Map<String, Function<Map<String, Object>, Object>> scripts = new HashMap<>();
        // plain "_score" is read through doubleValue(), same as the explicit doubleValue() variant
        scripts.put("log(doc['index'].value + (factor * _score))",
            vars -> scoringScript(vars, ScoreAccessor::doubleValue));
        scripts.put("log(doc['index'].value + (factor * _score.intValue()))",
            vars -> scoringScript(vars, ScoreAccessor::intValue));
        scripts.put("log(doc['index'].value + (factor * _score.longValue()))",
            vars -> scoringScript(vars, ScoreAccessor::longValue));
        scripts.put("log(doc['index'].value + (factor * _score.floatValue()))",
            vars -> scoringScript(vars, ScoreAccessor::floatValue));
        scripts.put("log(doc['index'].value + (factor * _score.doubleValue()))",
            vars -> scoringScript(vars, ScoreAccessor::doubleValue));
        return scripts;
    }

    /**
     * Evaluates {@code Math.log(index + factor * score)} from the script variables.
     *
     * @param vars    script variables; {@code "doc"} must hold a map whose {@code "index"} entry
     *                is a {@link ScriptDocValues} with at least one numeric value, {@code "_score"}
     *                must hold a {@link ScoreAccessor} and {@code "factor"} an {@link Integer}
     * @param scoring conversion applied to the {@link ScoreAccessor} to obtain the score
     * @return the natural logarithm of {@code index + factor * score}
     */
    static Double scoringScript(Map<String, Object> vars, Function<ScoreAccessor, Number> scoring) {
        final Map<?, ?> doc = (Map<?, ?>) vars.get("doc");
        final ScriptDocValues<?> indexValues = (ScriptDocValues<?>) doc.get("index");
        // only the first value of the "index" field is used
        final double index = ((Number) indexValues.getValues().get(0)).doubleValue();
        final double score = scoring.apply((ScoreAccessor) vars.get("_score")).doubleValue();
        final int factor = (Integer) vars.get("factor");
        return Math.log(index + (factor * score));
    }
}
/**
 * Checks that a random score function with a fixed seed, executed with a fixed search
 * preference, produces the same scores and ids on every execution, even though documents
 * are re-indexed and flushed between executions.
 */
public void testConsistentHitsWithSameSeed() throws Exception {
    createIndex("test");
    ensureGreen(); // make sure we are done otherwise preference could change?
    final int numDocs = randomIntBetween(100, 200);
    for (int id = 0; id < numDocs; id++) {
        index("test", "type", "" + id, jsonBuilder().startObject().endObject());
    }
    flush();
    refresh();

    final int rounds = scaledRandomIntBetween(10, 20);
    for (int round = 0; round < rounds; round++) {
        final int seed = randomInt();
        // The preference needs at least one char and must not start with '_', which is
        // reserved for known preference types (e.g. _shards, _primary).
        String preference;
        do {
            preference = randomRealisticUnicodeOfLengthBetween(1, 10);
        } while (preference.startsWith("_"));

        final int repetitions = scaledRandomIntBetween(2, 5);
        SearchHit[] expected = null;
        for (int attempt = 0; attempt < repetitions; attempt++) {
            final SearchResponse response = client().prepareSearch()
                .setSize(numDocs) // get all docs otherwise we are prone to tie-breaking
                .setPreference(preference)
                .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(seed)))
                .execute().actionGet();
            assertThat("Failures " + Arrays.toString(response.getShardFailures()),
                response.getShardFailures().length, CoreMatchers.equalTo(0));

            final SearchHit[] actual = response.getHits().getHits();
            final int actualCount = actual.length;
            // Re-sort by (score, id): with identical scores the returned order depends on
            // collection order, which might change between executions.
            ArrayUtil.timSort(actual, (left, right) -> {
                final int byScore = Float.compare(left.getScore(), right.getScore());
                if (byScore != 0) {
                    return byScore;
                }
                return left.getId().compareTo(right.getId());
            });

            if (attempt > 0) {
                // every later execution must match the first one hit by hit
                assertThat(expected.length, equalTo(actualCount));
                for (int pos = 0; pos < actualCount; pos++) {
                    assertThat("" + pos, actual[pos].getScore(), equalTo(expected[pos].getScore()));
                    assertThat("" + pos, actual[pos].getId(), equalTo(expected[pos].getId()));
                }
            } else {
                // first execution of this round becomes the reference
                assertThat(expected, nullValue());
                expected = actual;
            }

            // randomly change some docs to get them in different segments
            for (int remaining = randomIntBetween(20, 50); remaining > 0; remaining--) {
                final int docId = randomInt(numDocs - 1); // the upper bound is inclusive
                index("test", "type", "" + docId, jsonBuilder().startObject().endObject());
            }
            flush();
            refresh();
        }
    }
}
/**
 * Checks that a script score function can read {@code _score} directly and through each of
 * its {@code intValue()}, {@code longValue()}, {@code floatValue()} and {@code doubleValue()}
 * conversions. For every variant the query must run without failures and the best hit must
 * score above 1.
 */
public void testScoreAccessWithinScript() throws Exception {
    assertAcked(prepareCreate("test").addMapping("type", "body", "type=text", "index",
        "type=" + randomFrom("short", "float", "long", "integer", "double")));

    int docCount = randomIntBetween(100, 200);
    for (int i = 0; i < docCount; i++) {
        client().prepareIndex("test", "type", "" + i)
            // we add 1 to the index field to make sure that the scripts below never compute log(0)
            .setSource("body", randomFrom(Arrays.asList("foo", "bar", "baz")), "index", i + 1)
            .get();
    }
    refresh();

    Map<String, Object> params = new HashMap<>();
    params.put("factor", randomIntBetween(2, 4));

    // One script per way of accessing the score; the sources must match the keys
    // registered in CustomScriptPlugin#pluginScripts() exactly.
    final String[] scriptSources = {
        "log(doc['index'].value + (factor * _score))",
        "log(doc['index'].value + (factor * _score.intValue()))",
        "log(doc['index'].value + (factor * _score.longValue()))",
        "log(doc['index'].value + (factor * _score.floatValue()))",
        "log(doc['index'].value + (factor * _score.doubleValue()))"
    };

    for (String source : scriptSources) {
        Script script = new Script(ScriptType.INLINE, NAME, source, params);
        SearchResponse resp = client()
            .prepareSearch("test")
            .setQuery(
                functionScoreQuery(matchQuery("body", "foo"),
                    new FunctionScoreQueryBuilder.FilterFunctionBuilder[] {
                        new FunctionScoreQueryBuilder.FilterFunctionBuilder(fieldValueFactorFunction("index").factor(2)),
                        new FunctionScoreQueryBuilder.FilterFunctionBuilder(scriptFunction(script))
                    }
                ))
            .get();
        assertNoFailures(resp);
        SearchHit firstHit = resp.getHits().getAt(0);
        // name the script in the failure message so a failing variant is identifiable
        assertThat("score of first hit for script [" + source + "]", firstHit.getScore(), greaterThan(1f));
    }
}
/**
 * Checks that the explanation of a hit scored by a seeded random score function
 * contains the seed value.
 */
public void testSeedReportedInExplain() throws Exception {
    createIndex("test");
    ensureGreen();
    index("test", "type", "1", jsonBuilder().startObject().endObject());
    flush();
    refresh();

    final int seed = 12345678;
    final SearchResponse response = client().prepareSearch("test")
        .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(seed)))
        .setExplain(true)
        .get();
    assertNoFailures(response);
    assertEquals(1, response.getHits().getTotalHits());

    // the single indexed document is the only hit; its explanation must mention the seed
    final String explanation = response.getHits().getAt(0).getExplanation().toString();
    assertThat(explanation, containsString(String.valueOf(seed)));
}
/**
 * Checks that a seeded random score function query against an empty index
 * completes without failures and reports zero hits.
 */
public void testNoDocs() throws Exception {
    createIndex("test");
    ensureGreen();

    final SearchResponse response = client().prepareSearch("test")
        .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(1234)))
        .get();
    assertNoFailures(response);

    final long totalHits = response.getHits().getTotalHits();
    assertEquals(0, totalHits);
}
/**
 * Checks that every score produced by a seeded random score function lies in the
 * closed range [0.0, 1.0], for several random seeds.
 */
public void testScoreRange() throws Exception {
    createIndex("test");
    ensureGreen();

    final int numDocs = randomIntBetween(100, 200);
    for (int n = 0; n < numDocs; n++) {
        // ids are random unicode strings rather than sequential numbers
        final String docId = randomRealisticUnicodeOfCodepointLengthBetween(1, 50);
        index("test", "type", docId, jsonBuilder().startObject().endObject());
    }
    flush();
    refresh();

    final int rounds = scaledRandomIntBetween(10, 20);
    for (int round = 0; round < rounds; round++) {
        final int seed = randomInt();
        final SearchResponse response = client().prepareSearch()
            .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(seed)))
            .setSize(numDocs)
            .execute().actionGet();
        assertNoFailures(response);

        final SearchHit[] hits = response.getHits().getHits();
        for (int pos = 0; pos < hits.length; pos++) {
            assertThat(hits[pos].getScore(), allOf(greaterThanOrEqualTo(0.0f), lessThanOrEqualTo(1.0f)));
        }
    }
}
/**
 * Checks that the random score function accepts an {@code int}, a {@code long} and a
 * {@code String} seed: a query built with each seed type must run without failures.
 */
public void testSeeds() throws Exception {
    createIndex("test");
    ensureGreen();

    final int numDocs = randomIntBetween(100, 200);
    for (int id = 0; id < numDocs; id++) {
        index("test", "type", "" + id, jsonBuilder().startObject().endObject());
    }
    flushAndRefresh();

    // int seed
    final SearchResponse intSeeded = client().prepareSearch()
        .setSize(numDocs) // get all docs otherwise we are prone to tie-breaking
        .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(randomInt())))
        .execute().actionGet();
    assertNoFailures(intSeeded);

    // long seed
    final SearchResponse longSeeded = client().prepareSearch()
        .setSize(numDocs) // get all docs otherwise we are prone to tie-breaking
        .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(randomLong())))
        .execute().actionGet();
    assertNoFailures(longSeeded);

    // string seed
    final SearchResponse stringSeeded = client().prepareSearch()
        .setSize(numDocs) // get all docs otherwise we are prone to tie-breaking
        .setQuery(functionScoreQuery(matchAllQuery(), randomFunction(randomRealisticUnicodeOfLengthBetween(10, 20))))
        .execute().actionGet();
    assertNoFailures(stringSeeded);
}
/**
 * Diagnostic helper that logs statistics about which document ends up as the top hit of
 * an unseeded random score query. It indexes {@code count} documents with ids
 * {@code 0..count-1}, runs {@code count} searches, counts how often each id was the first
 * hit and logs the max/average repeat, the fraction of ids that were ever first, selected
 * percentiles and the mean id. It asserts nothing about the distribution itself.
 *
 * NOTE(review): the name does not start with "test", so this is presumably not picked up
 * by the test runner and is meant to be invoked manually — confirm.
 */
public void checkDistribution() throws Exception {
    int count = 10000;

    assertAcked(prepareCreate("test"));
    ensureGreen();

    for (int i = 0; i < count; i++) {
        index("test", "type", "" + i, jsonBuilder().startObject().endObject());
    }

    flush();
    refresh();

    // matrix[id] = number of searches in which document <id> was the top hit
    int[] matrix = new int[count];

    for (int i = 0; i < count; i++) {
        SearchResponse searchResponse = client().prepareSearch()
            .setQuery(functionScoreQuery(matchAllQuery(), new RandomScoreFunctionBuilder()))
            .execute().actionGet();

        // parseInt yields the primitive directly; valueOf would box only to unbox again
        matrix[Integer.parseInt(searchResponse.getHits().getAt(0).getId())]++;
    }

    int filled = 0;
    int maxRepeat = 0;
    int sumRepeat = 0;
    for (int value : matrix) {
        sumRepeat += value;
        maxRepeat = Math.max(maxRepeat, value);
        if (value > 0) {
            filled++;
        }
    }

    logger.info("max repeat: {}", maxRepeat);
    logger.info("avg repeat: {}", sumRepeat / (double) filled);
    logger.info("distribution: {}", filled / (double) count);

    // Countdown counters over the ids that were a top hit at least once: each one is
    // decremented per such id, and the id at which a counter reaches zero is logged.
    int percentile50 = filled / 2;
    int percentile25 = (filled / 4);
    int percentile75 = percentile50 + percentile25;

    // long accumulator: id * hits summed over all ids would overflow an int for larger counts
    long sum = 0;

    for (int i = 0; i < matrix.length; i++) {
        if (matrix[i] == 0) {
            continue;
        }
        sum += (long) i * matrix[i];
        if (percentile50 == 0) {
            logger.info("median: {}", i);
        } else if (percentile25 == 0) {
            logger.info("percentile_25: {}", i);
        } else if (percentile75 == 0) {
            logger.info("percentile_75: {}", i);
        }
        percentile50--;
        percentile25--;
        percentile75--;
    }

    logger.info("mean: {}", sum / (double) count);
}
}