/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.query.functionscore;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery.FilterFunction;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentLocation;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.QueryShardContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
/**
* A query that uses filters, each with a score function associated with it, to
* compute the score.
*/
public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScoreQueryBuilder> {
/** Name of this query; used as the XContent object name and as the writeable name. */
public static final String NAME = "function_score";
// Shared prefix of the error raised when both a [functions] array and a single function are declared
static final String MISPLACED_FUNCTION_MESSAGE_PREFIX = "you can either define [functions] array or a single function, not both. ";
// Field names recognised when parsing and used when rendering this query
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
public static final ParseField QUERY_FIELD = new ParseField("query");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField FUNCTIONS_FIELD = new ParseField("functions");
public static final ParseField SCORE_MODE_FIELD = new ParseField("score_mode");
public static final ParseField BOOST_MODE_FIELD = new ParseField("boost_mode");
public static final ParseField MAX_BOOST_FIELD = new ParseField("max_boost");
public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");
/** Boost mode applied by doToQuery when none was set and no function supplies its own default. */
public static final CombineFunction DEFAULT_BOOST_MODE = CombineFunction.MULTIPLY;
/** Initial value of the score mode. */
public static final FiltersFunctionScoreQuery.ScoreMode DEFAULT_SCORE_MODE = FiltersFunctionScoreQuery.ScoreMode.MULTIPLY;
// The query whose matching documents get scored; the public constructors reject null
private final QueryBuilder query;
private float maxBoost = FunctionScoreQuery.DEFAULT_MAX_BOOST;
private FiltersFunctionScoreQuery.ScoreMode scoreMode = DEFAULT_SCORE_MODE;
// null until set explicitly; doToQuery chooses a default in that case
private CombineFunction boostMode;
// null means no minimum score was configured
private Float minScore = null;
// Filter/function pairs; the array passed to the constructor is stored as-is (not copied)
private final FilterFunctionBuilder[] filterFunctionBuilders;
/**
 * Builds a function_score query around the given query with an empty set of functions.
 *
 * @param query the query that needs to be custom scored; must not be null
 */
public FunctionScoreQueryBuilder(QueryBuilder query) {
    this(query, new FilterFunctionBuilder[] {});
}
/**
* Creates a function_score query that executes the provided filters and functions on all documents.
* The query part is a {@link MatchAllQueryBuilder}.
*
* @param filterFunctionBuilders the filters and functions; neither the array nor any of its elements may be null
*/
public FunctionScoreQueryBuilder(FilterFunctionBuilder[] filterFunctionBuilders) {
this(new MatchAllQueryBuilder(), filterFunctionBuilders);
}
/**
 * Builds a function_score query that applies a single score function to every document,
 * using a {@link MatchAllQueryBuilder} as the query part.
 *
 * @param scoreFunctionBuilder score function that is executed; must not be null
 */
public FunctionScoreQueryBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder) {
    // Delegates through the (query, function) overload, which wraps the function in a one-element array.
    this(new MatchAllQueryBuilder(), scoreFunctionBuilder);
}
/**
* Creates a function_score query that will execute the function provided in the context of the provided query.
* The function is wrapped in a single {@link FilterFunctionBuilder} whose filter matches all documents.
*
* @param query the query to custom score; must not be null
* @param scoreFunctionBuilder score function that is executed; must not be null
*/
public FunctionScoreQueryBuilder(QueryBuilder query, ScoreFunctionBuilder<?> scoreFunctionBuilder) {
this(query, new FilterFunctionBuilder[]{new FilterFunctionBuilder(scoreFunctionBuilder)});
}
/**
* Creates a function_score query that executes the provided filters and functions on documents that match a query.
*
* @param query the query that defines which documents the function_score query will be executed on.
* @param filterFunctionBuilders the filters and functions
*/
public FunctionScoreQueryBuilder(QueryBuilder query, FilterFunctionBuilder[] filterFunctionBuilders) {
if (query == null) {
throw new IllegalArgumentException("function_score: query must not be null");
}
if (filterFunctionBuilders == null) {
throw new IllegalArgumentException("function_score: filters and functions array must not be null");
}
for (FilterFunctionBuilder filterFunctionBuilder : filterFunctionBuilders) {
if (filterFunctionBuilder == null) {
throw new IllegalArgumentException("function_score: each filter and function must not be null");
}
}
this.query = query;
this.filterFunctionBuilders = filterFunctionBuilders;
}
/**
* Read from a stream.
*/
public FunctionScoreQueryBuilder(StreamInput in) throws IOException {
super(in);
query = in.readNamedWriteable(QueryBuilder.class);
filterFunctionBuilders = in.readList(FilterFunctionBuilder::new).toArray(new FilterFunctionBuilder[0]);
maxBoost = in.readFloat();
minScore = in.readOptionalFloat();
boostMode = in.readOptionalWriteable(CombineFunction::readFromStream);
scoreMode = FiltersFunctionScoreQuery.ScoreMode.readFromStream(in);
}
/**
 * Serializes this builder's own fields: query, functions, max boost, min score, boost mode,
 * score mode. The stream constructor reads them back in the same order.
 */
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
    out.writeNamedWriteable(this.query);
    List<FilterFunctionBuilder> functions = Arrays.asList(this.filterFunctionBuilders);
    out.writeList(functions);
    out.writeFloat(this.maxBoost);
    out.writeOptionalFloat(this.minScore);
    out.writeOptionalWriteable(this.boostMode);
    this.scoreMode.writeTo(out);
}
/**
 * Gets the query selecting the documents this function_score query is executed on.
 */
public QueryBuilder query() {
    return query;
}
/**
 * Gets the filters and functions. The internal array itself is returned, not a copy.
 */
public FilterFunctionBuilder[] filterFunctionBuilders() {
    return filterFunctionBuilders;
}
/**
 * Sets the score mode, which defines how results of individual score functions will be aggregated.
 *
 * @param scoreMode the score mode to use
 * @return this builder
 * @throws IllegalArgumentException if {@code scoreMode} is null
 * @see org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery.ScoreMode
 */
public FunctionScoreQueryBuilder scoreMode(FiltersFunctionScoreQuery.ScoreMode scoreMode) {
    if (scoreMode != null) {
        this.scoreMode = scoreMode;
        return this;
    }
    throw new IllegalArgumentException("[" + NAME + "] requires 'score_mode' field");
}
/**
 * Gets the score mode, i.e. how results of individual score functions will be aggregated.
 *
 * @see org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery.ScoreMode
 */
public FiltersFunctionScoreQuery.ScoreMode scoreMode() {
    return scoreMode;
}
/**
 * Sets the boost mode, which defines how the combined result of score functions will influence
 * the final score together with the sub query score.
 *
 * @param combineFunction the boost mode to use
 * @return this builder
 * @throws IllegalArgumentException if {@code combineFunction} is null
 * @see CombineFunction
 */
public FunctionScoreQueryBuilder boostMode(CombineFunction combineFunction) {
    if (combineFunction != null) {
        this.boostMode = combineFunction;
        return this;
    }
    throw new IllegalArgumentException("[" + NAME + "] requires 'boost_mode' field");
}
/**
 * Gets the boost mode, i.e. how the combined result of score functions will influence the final
 * score together with the sub query score. May be null when no boost mode has been set.
 *
 * @see CombineFunction
 */
public CombineFunction boostMode() {
    return boostMode;
}
/**
 * Sets the maximum boost that will be applied by function score.
 *
 * @param maxBoost the upper bound to store
 * @return this builder
 */
public FunctionScoreQueryBuilder maxBoost(float maxBoost) {
    final FunctionScoreQueryBuilder self = this;
    self.maxBoost = maxBoost;
    return self;
}
/**
 * Gets the maximum boost that will be applied by function score.
 */
public float maxBoost() {
    return maxBoost;
}
/**
 * Renders this query as a {@code function_score} object: the query (if any), the functions array,
 * score mode, boost mode (only when set), max boost, min score (only when set), followed by the
 * boost and query name written by {@code printBoostAndQueryName}.
 */
@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject(NAME);

    if (query != null) {
        builder.field(QUERY_FIELD.getPreferredName());
        query.toXContent(builder, params);
    }

    builder.startArray(FUNCTIONS_FIELD.getPreferredName());
    for (int idx = 0; idx < filterFunctionBuilders.length; idx++) {
        filterFunctionBuilders[idx].toXContent(builder, params);
    }
    builder.endArray();

    // Modes are rendered as their lower-cased enum names.
    String scoreModeName = scoreMode.name().toLowerCase(Locale.ROOT);
    builder.field(SCORE_MODE_FIELD.getPreferredName(), scoreModeName);
    if (boostMode != null) {
        String boostModeName = boostMode.name().toLowerCase(Locale.ROOT);
        builder.field(BOOST_MODE_FIELD.getPreferredName(), boostModeName);
    }

    builder.field(MAX_BOOST_FIELD.getPreferredName(), maxBoost);
    if (minScore != null) {
        builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
    }

    printBoostAndQueryName(builder);
    builder.endObject();
}
/**
 * Sets the minimum score.
 *
 * @param minScore the minimum score to store
 * @return this builder
 */
public FunctionScoreQueryBuilder setMinScore(float minScore) {
    final FunctionScoreQueryBuilder self = this;
    self.minScore = minScore;
    return self;
}
/**
 * Gets the minimum score, or null when none has been set.
 */
public Float getMinScore() {
    return minScore;
}
/**
 * Returns the writeable name of this query, which is {@link #NAME}.
 */
@Override
public String getWriteableName() {
    return FunctionScoreQueryBuilder.NAME;
}
/**
 * Compares the fields specific to this builder: query, functions, boost mode, score mode,
 * min score and max boost.
 */
@Override
protected boolean doEquals(FunctionScoreQueryBuilder other) {
    if (Objects.equals(query, other.query) == false) {
        return false;
    }
    if (Arrays.equals(filterFunctionBuilders, other.filterFunctionBuilders) == false) {
        return false;
    }
    if (Objects.equals(boostMode, other.boostMode) == false) {
        return false;
    }
    if (Objects.equals(scoreMode, other.scoreMode) == false) {
        return false;
    }
    if (Objects.equals(minScore, other.minScore) == false) {
        return false;
    }
    // Float.compare == 0 gives the same answer as boxed Float.equals (NaN equals NaN, 0.0f differs from -0.0f).
    return Float.compare(maxBoost, other.maxBoost) == 0;
}
/**
 * Hashes the same fields that {@code doEquals} compares; the functions array contributes
 * through its element-wise hash.
 */
@Override
protected int doHashCode() {
    final int functionsHash = Arrays.hashCode(filterFunctionBuilders);
    return Objects.hash(query, functionsHash, boostMode, scoreMode, minScore, maxBoost);
}
/**
 * Converts this builder into a Lucene query.
 * <p>
 * Every filter/function pair is converted first, then the inner query (replaced by a
 * {@link MatchAllDocsQuery} when it converts to null). With no function, or with exactly one
 * function whose filter is a match_all query, a {@link FunctionScoreQuery} is returned;
 * otherwise a {@link FiltersFunctionScoreQuery}.
 */
@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
    final int functionCount = filterFunctionBuilders.length;
    FilterFunction[] luceneFunctions = new FilterFunction[functionCount];
    for (int idx = 0; idx < functionCount; idx++) {
        FilterFunctionBuilder current = filterFunctionBuilders[idx];
        Query luceneFilter = current.getFilter().toQuery(context);
        ScoreFunction luceneFunction = current.getScoreFunction().toFunction(context);
        luceneFunctions[idx] = new FilterFunction(luceneFilter, luceneFunction);
    }

    Query innerQuery = this.query.toQuery(context);
    if (innerQuery == null) {
        innerQuery = new MatchAllDocsQuery();
    }

    final boolean noFunction = functionCount == 0;
    final boolean singleUnfilteredFunction = functionCount == 1
        && this.filterFunctionBuilders[0].getFilter().getName().equals(MatchAllQueryBuilder.NAME);

    if (noFunction == false && singleUnfilteredFunction == false) {
        // general case: several functions, or a single function restricted by a real filter
        CombineFunction filtersBoostMode = this.boostMode != null ? this.boostMode : DEFAULT_BOOST_MODE;
        return new FiltersFunctionScoreQuery(innerQuery, scoreMode, luceneFunctions, maxBoost, minScore, filtersBoostMode);
    }

    // at most one function and no real filter: a plain FunctionScoreQuery is enough
    ScoreFunction singleFunction = noFunction ? null : luceneFunctions[0].function;
    CombineFunction singleBoostMode = this.boostMode;
    if (singleBoostMode == null) {
        // fall back to the function's own default combiner when there is a function
        singleBoostMode = singleFunction == null ? DEFAULT_BOOST_MODE : singleFunction.getDefaultScoreCombiner();
    }
    return new FunctionScoreQuery(innerQuery, singleFunction, minScore, singleBoostMode, maxBoost);
}
/**
 * A score function paired with a filter: the function is meant to be executed only for the
 * documents matching the filter. When no filter is supplied, a match_all filter is used.
 */
public static class FilterFunctionBuilder implements ToXContent, Writeable {
    private final QueryBuilder filter;
    private final ScoreFunctionBuilder<?> scoreFunction;

    /**
     * Creates an entry whose function applies to every document (match_all filter).
     *
     * @param scoreFunctionBuilder the score function; must not be null
     */
    public FilterFunctionBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder) {
        this(new MatchAllQueryBuilder(), scoreFunctionBuilder);
    }

    /**
     * Creates an entry from an explicit filter and function.
     *
     * @param filter        the filter selecting the documents the function applies to
     * @param scoreFunction the score function
     * @throws IllegalArgumentException if either argument is null
     */
    public FilterFunctionBuilder(QueryBuilder filter, ScoreFunctionBuilder<?> scoreFunction) {
        if (null == filter) {
            throw new IllegalArgumentException("function_score: filter must not be null");
        }
        if (null == scoreFunction) {
            throw new IllegalArgumentException("function_score: function must not be null");
        }
        this.scoreFunction = scoreFunction;
        this.filter = filter;
    }

    /**
     * Deserializes an entry from a stream: filter first, then score function.
     */
    public FilterFunctionBuilder(StreamInput in) throws IOException {
        this.filter = in.readNamedWriteable(QueryBuilder.class);
        this.scoreFunction = in.readNamedWriteable(ScoreFunctionBuilder.class);
    }

    /**
     * Serializes this entry: filter first, then score function.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeNamedWriteable(this.filter);
        out.writeNamedWriteable(this.scoreFunction);
    }

    /** Returns the filter of this entry. */
    public QueryBuilder getFilter() {
        return this.filter;
    }

    /** Returns the score function of this entry. */
    public ScoreFunctionBuilder<?> getScoreFunction() {
        return this.scoreFunction;
    }

    /**
     * Renders this entry as an object holding the filter under the filter field,
     * followed by whatever the score function writes.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        final String filterFieldName = FILTER_FIELD.getPreferredName();
        builder.field(filterFieldName);
        this.filter.toXContent(builder, params);
        this.scoreFunction.toXContent(builder, params);
        builder.endObject();
        return builder;
    }

    @Override
    public int hashCode() {
        // Same value as Objects.hash(filter, scoreFunction), spelled out.
        return 31 * (31 + Objects.hashCode(this.filter)) + Objects.hashCode(this.scoreFunction);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        final FilterFunctionBuilder other = (FilterFunctionBuilder) obj;
        if (Objects.equals(this.filter, other.filter) == false) {
            return false;
        }
        return Objects.equals(this.scoreFunction, other.scoreFunction);
    }

    /**
     * Rewrites the filter. Returns this instance when the filter rewrites to itself,
     * otherwise a new entry pairing the rewritten filter with the same score function.
     */
    public FilterFunctionBuilder rewrite(QueryRewriteContext context) throws IOException {
        final QueryBuilder rewrittenFilter = this.filter.rewrite(context);
        return rewrittenFilter == this.filter ? this : new FilterFunctionBuilder(rewrittenFilter, this.scoreFunction);
    }
}
/**
 * Rewrites the inner query and every filter/function entry. Returns this builder when nothing
 * changed; otherwise a new builder holding the rewritten parts plus this builder's score mode,
 * min score, max boost and boost mode.
 */
@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
    final QueryBuilder rewrittenQuery = this.query.rewrite(queryRewriteContext);

    final int count = this.filterFunctionBuilders.length;
    final FilterFunctionBuilder[] rewrittenFunctions = new FilterFunctionBuilder[count];
    boolean anyFunctionChanged = false;
    for (int idx = 0; idx < count; idx++) {
        final FilterFunctionBuilder original = filterFunctionBuilders[idx];
        final FilterFunctionBuilder candidate = original.rewrite(queryRewriteContext);
        if (candidate != original) {
            anyFunctionChanged = true;
        }
        rewrittenFunctions[idx] = candidate;
    }

    if (rewrittenQuery == query && anyFunctionChanged == false) {
        return this;
    }

    final FunctionScoreQueryBuilder copy = new FunctionScoreQueryBuilder(rewrittenQuery, rewrittenFunctions);
    copy.scoreMode = this.scoreMode;
    copy.minScore = this.minScore;
    copy.maxBoost = this.maxBoost;
    copy.boostMode = this.boostMode;
    return copy;
}
/**
 * Collects inner hit builders from the inner query into {@code innerHits}.
 * Only {@link #query()} is inspected here.
 */
@Override
protected void extractInnerHitBuilders(Map<String, InnerHitBuilder> innerHits) {
    final QueryBuilder innerQuery = query();
    InnerHitBuilder.extractInnerHits(innerQuery, innerHits);
}
/**
* Parses a function_score query from the parser held by the given context, consuming tokens until
* the END_OBJECT that closes the query.
* <p>
* Recognised content: a [query] object, a [functions] array, a single function declared as an
* object, a top-level [weight] value (treated as a single weight function), and the value fields
* score_mode, boost_mode, max_boost, min_score, boost and the query name. A [functions] array and
* a single function are mutually exclusive. When no [query] is given, a match_all query is used.
*
* @param parseContext the context providing the parser
* @return the parsed builder
* @throws ParsingException if [query] is defined twice, more than one single function is given,
*         a single function is mixed with the [functions] array, or an unsupported array or
*         field is encountered
* @throws IOException propagated from the parser
*/
public static FunctionScoreQueryBuilder fromXContent(QueryParseContext parseContext) throws IOException {
XContentParser parser = parseContext.parser();
QueryBuilder query = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
String queryName = null;
FiltersFunctionScoreQuery.ScoreMode scoreMode = FunctionScoreQueryBuilder.DEFAULT_SCORE_MODE;
float maxBoost = FunctionScoreQuery.DEFAULT_MAX_BOOST;
Float minScore = null;
String currentFieldName = null;
XContentParser.Token token;
CombineFunction combineFunction = null;
// Either define array of functions and filters or only one function
boolean functionArrayFound = false;
boolean singleFunctionFound = false;
String singleFunctionName = null;
List<FunctionScoreQueryBuilder.FilterFunctionBuilder> filterFunctionBuilders = new ArrayList<>();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token == XContentParser.Token.START_OBJECT) {
if (QUERY_FIELD.match(currentFieldName)) {
if (query != null) {
throw new ParsingException(parser.getTokenLocation(), "failed to parse [{}] query. [query] is already defined.",
NAME);
}
query = parseContext.parseInnerQueryBuilder();
} else {
// any other object is taken to be a single score function named after the field
if (singleFunctionFound) {
throw new ParsingException(parser.getTokenLocation(),
"failed to parse [{}] query. already found function [{}], now encountering [{}]. use [functions] "
+ "array if you want to define several functions.",
NAME, singleFunctionName, currentFieldName);
}
if (functionArrayFound) {
String errorString = "already found [functions] array, now encountering [" + currentFieldName + "].";
handleMisplacedFunctionsDeclaration(parser.getTokenLocation(), errorString);
}
singleFunctionFound = true;
singleFunctionName = currentFieldName;
ScoreFunctionBuilder<?> scoreFunction = parser.namedObject(ScoreFunctionBuilder.class, currentFieldName,
parseContext);
filterFunctionBuilders.add(new FunctionScoreQueryBuilder.FilterFunctionBuilder(scoreFunction));
}
} else if (token == XContentParser.Token.START_ARRAY) {
if (FUNCTIONS_FIELD.match(currentFieldName)) {
if (singleFunctionFound) {
String errorString = "already found [" + singleFunctionName + "], now encountering [functions].";
handleMisplacedFunctionsDeclaration(parser.getTokenLocation(), errorString);
}
functionArrayFound = true;
// the helper returns the last field name it read inside the array (null if it read none)
currentFieldName = parseFiltersAndFunctions(parseContext, filterFunctionBuilders);
} else {
throw new ParsingException(parser.getTokenLocation(), "failed to parse [{}] query. array [{}] is not supported",
NAME, currentFieldName);
}
} else if (token.isValue()) {
if (SCORE_MODE_FIELD.match(currentFieldName)) {
scoreMode = FiltersFunctionScoreQuery.ScoreMode.fromString(parser.text());
} else if (BOOST_MODE_FIELD.match(currentFieldName)) {
combineFunction = CombineFunction.fromString(parser.text());
} else if (MAX_BOOST_FIELD.match(currentFieldName)) {
maxBoost = parser.floatValue();
} else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName)) {
boost = parser.floatValue();
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName)) {
queryName = parser.text();
} else if (MIN_SCORE_FIELD.match(currentFieldName)) {
minScore = parser.floatValue();
} else {
// note: the single-function/array conflict checks run before the field is checked against [weight],
// so an unknown field that follows a function is reported as a conflict rather than as unsupported
if (singleFunctionFound) {
throw new ParsingException(parser.getTokenLocation(),
"failed to parse [{}] query. already found function [{}], now encountering [{}]. use [functions] array "
+ "if you want to define several functions.",
NAME, singleFunctionName, currentFieldName);
}
if (functionArrayFound) {
String errorString = "already found [functions] array, now encountering [" + currentFieldName + "].";
handleMisplacedFunctionsDeclaration(parser.getTokenLocation(), errorString);
}
if (WEIGHT_FIELD.match(currentFieldName)) {
// a top-level weight counts as a single function
filterFunctionBuilders.add(
new FunctionScoreQueryBuilder.FilterFunctionBuilder(new WeightBuilder().setWeight(parser.floatValue())));
singleFunctionFound = true;
singleFunctionName = currentFieldName;
} else {
throw new ParsingException(parser.getTokenLocation(), "failed to parse [{}] query. field [{}] is not supported",
NAME, currentFieldName);
}
}
}
}
// no explicit query: score all documents
if (query == null) {
query = new MatchAllQueryBuilder();
}
FunctionScoreQueryBuilder functionScoreQueryBuilder = new FunctionScoreQueryBuilder(query,
filterFunctionBuilders.toArray(new FunctionScoreQueryBuilder.FilterFunctionBuilder[filterFunctionBuilders.size()]));
// boostMode(...) rejects null, so it is only called when a boost mode was parsed
if (combineFunction != null) {
functionScoreQueryBuilder.boostMode(combineFunction);
}
functionScoreQueryBuilder.scoreMode(scoreMode);
functionScoreQueryBuilder.maxBoost(maxBoost);
if (minScore != null) {
functionScoreQueryBuilder.setMinScore(minScore);
}
functionScoreQueryBuilder.boost(boost);
functionScoreQueryBuilder.queryName(queryName);
return functionScoreQueryBuilder;
}
/**
 * Always throws a {@link ParsingException} reporting that a [functions] array and a single
 * function were both declared.
 *
 * @param contentLocation location reported with the exception
 * @param errorString     detail appended to the common message prefix
 */
private static void handleMisplacedFunctionsDeclaration(XContentLocation contentLocation, String errorString) {
    final String detail = MISPLACED_FUNCTION_MESSAGE_PREFIX + errorString;
    throw new ParsingException(contentLocation, "failed to parse [{}] query. [{}]", NAME, detail);
}
/**
 * Parses the entries of a [functions] array, consuming tokens until the closing END_ARRAY.
 * <p>
 * Each entry must be an object. Inside an entry, a [filter] object is parsed as a query, any other
 * object is parsed as the entry's score function, and a [weight] value is either applied to that
 * function or, when the entry has no function, turned into a weight-only function. An entry
 * without a filter gets a match_all filter.
 *
 * @param parseContext           the context providing the parser
 * @param filterFunctionBuilders list the parsed entries are appended to
 * @return the last field name read inside the array, or null if none was read
 * @throws ParsingException if an entry is not an object, declares more than one function,
 *         contains an unsupported value field, or ends up without a function
 * @throws IOException propagated from the parser
 */
private static String parseFiltersAndFunctions(QueryParseContext parseContext,
        List<FunctionScoreQueryBuilder.FilterFunctionBuilder> filterFunctionBuilders) throws IOException {
    String currentFieldName = null;
    XContentParser.Token token;
    XContentParser parser = parseContext.parser();
    while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
        if (token != XContentParser.Token.START_OBJECT) {
            // Arguments must follow the placeholder order: query name, expected token, actual token.
            throw new ParsingException(parser.getTokenLocation(),
                "failed to parse [{}]. malformed query, expected a [{}] while parsing functions but got a [{}] instead",
                NAME, XContentParser.Token.START_OBJECT, token);
        }
        QueryBuilder filter = null;
        ScoreFunctionBuilder<?> scoreFunction = null;
        Float functionWeight = null;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                if (FILTER_FIELD.match(currentFieldName)) {
                    filter = parseContext.parseInnerQueryBuilder();
                } else {
                    // any object other than [filter] is the entry's score function; only one is allowed
                    if (scoreFunction != null) {
                        throw new ParsingException(parser.getTokenLocation(),
                            "failed to parse function_score functions. already found [{}], now encountering [{}].",
                            scoreFunction.getName(), currentFieldName);
                    }
                    scoreFunction = parser.namedObject(ScoreFunctionBuilder.class, currentFieldName, parseContext);
                }
            } else if (token.isValue()) {
                if (WEIGHT_FIELD.match(currentFieldName)) {
                    functionWeight = parser.floatValue();
                } else {
                    throw new ParsingException(parser.getTokenLocation(), "failed to parse [{}] query. field [{}] is not supported",
                        NAME, currentFieldName);
                }
            }
        }
        // The weight is applied after the whole entry is read, so it may appear before or after the function.
        if (functionWeight != null) {
            if (scoreFunction == null) {
                scoreFunction = new WeightBuilder().setWeight(functionWeight);
            } else {
                scoreFunction.setWeight(functionWeight);
            }
        }
        if (filter == null) {
            filter = new MatchAllQueryBuilder();
        }
        if (scoreFunction == null) {
            throw new ParsingException(parser.getTokenLocation(),
                "failed to parse [{}] query. an entry in functions list is missing a function.", NAME);
        }
        filterFunctionBuilders.add(new FunctionScoreQueryBuilder.FilterFunctionBuilder(filter, scoreFunction));
    }
    return currentFieldName;
}
}