/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.geo.GeoDistance;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.geo.GeoUtils;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.LeafScoreFunction;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.unit.DistanceUnit;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.fielddata.IndexGeoPointFieldData;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.index.fielddata.MultiGeoPointValues;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.index.mapper.GeoPointFieldMapper.GeoPointFieldType;
import org.elasticsearch.index.mapper.DateFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.MultiValueMode;

import java.io.IOException;
import java.util.Objects;

/**
 * Base builder for decay score functions over a single field.
 * <p>
 * The function parameters ({@code origin}, {@code scale}, {@code offset}, {@code decay}) are kept as raw
 * bytes and are only parsed in {@link #doToFunction(QueryShardContext)}, because how they must be
 * interpreted depends on the mapped type of the field (date, geo point or number).
 *
 * @param <DFB> the concrete builder type, returned by the fluent setters
 */
public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>> extends ScoreFunctionBuilder<DFB> {

    protected static final String ORIGIN = "origin";
    protected static final String SCALE = "scale";
    protected static final String DECAY = "decay";
    protected static final String OFFSET = "offset";

    /** Decay used when none is given explicitly. */
    public static double DEFAULT_DECAY = 0.5;
    /** Multi-value mode used when none is set through {@link #setMultiValueMode(MultiValueMode)}. */
    public static MultiValueMode DEFAULT_MULTI_VALUE_MODE = MultiValueMode.MIN;

    private final String fieldName;
    // parsing of origin, scale, offset and decay depends on the field type,
    // delayed to the data node that has the mapping for it
    private final BytesReference functionBytes;
    private MultiValueMode multiValueMode = DEFAULT_MULTI_VALUE_MODE;

    /**
     * Convenience constructor that converts its parameters into json to parse on the data nodes.
     * Uses {@link #DEFAULT_DECAY} as the decay.
     *
     * @param fieldName the field the function is computed on, must not be {@code null}
     * @param origin    the origin, may be {@code null} (it is then left out of the serialized function)
     * @param scale     the scale, must not be {@code null}
     * @param offset    the offset, may be {@code null} (it is then left out of the serialized function)
     */
    protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset) {
        this(fieldName, origin, scale, offset, DEFAULT_DECAY);
    }

    /**
     * Convenience constructor that converts its parameters into json to parse on the data nodes.
     *
     * @param fieldName the field the function is computed on, must not be {@code null}
     * @param origin    the origin, may be {@code null} (it is then left out of the serialized function)
     * @param scale     the scale, must not be {@code null}
     * @param offset    the offset, may be {@code null} (it is then left out of the serialized function)
     * @param decay     the decay, must be strictly between 0 and 1
     * @throws IllegalArgumentException if {@code fieldName} or {@code scale} is {@code null}, or if the
     *                                  json object cannot be built
     * @throws IllegalStateException    if {@code decay} is not strictly between 0 and 1
     */
    protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
        if (fieldName == null) {
            throw new IllegalArgumentException("decay function: field name must not be null");
        }
        if (scale == null) {
            throw new IllegalArgumentException("decay function: scale must not be null");
        }
        if (decay <= 0 || decay >= 1.0) {
            throw new IllegalStateException("decay function: decay must be in range 0..1!");
        }
        this.fieldName = fieldName;
        try {
            XContentBuilder builder = XContentFactory.jsonBuilder();
            builder.startObject();
            if (origin != null) {
                builder.field(ORIGIN, origin);
            }
            builder.field(SCALE, scale);
            if (offset != null) {
                builder.field(OFFSET, offset);
            }
            builder.field(DECAY, decay);
            builder.endObject();
            this.functionBytes = builder.bytes();
        } catch (IOException e) {
            throw new IllegalArgumentException("unable to build inner function object", e);
        }
    }

    /**
     * Creates a builder from an already serialized function object.
     *
     * @param fieldName     the field the function is computed on, must not be {@code null}
     * @param functionBytes the serialized function parameters, must not be {@code null}
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    protected DecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
        if (fieldName == null) {
            throw new IllegalArgumentException("decay function: field name must not be null");
        }
        if (functionBytes == null) {
            throw new IllegalArgumentException("decay function: function must not be null");
        }
        this.fieldName = fieldName;
        this.functionBytes = functionBytes;
    }

    /**
     * Read from a stream.
     */
    protected DecayFunctionBuilder(StreamInput in) throws IOException {
        super(in);
        fieldName = in.readString();
        functionBytes = in.readBytesReference();
        multiValueMode = MultiValueMode.readMultiValueModeFrom(in);
    }

    /**
     * Writes the field name, the raw function bytes and the multi-value mode, in the same order in which
     * {@link #DecayFunctionBuilder(StreamInput)} reads them.
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(fieldName);
        out.writeBytesReference(functionBytes);
        multiValueMode.writeTo(out);
    }

    /** Returns the name of the field this function is computed on. */
    public String getFieldName() {
        return this.fieldName;
    }

    /** Returns the raw, not yet parsed, function parameters. */
    public BytesReference getFunctionBytes() {
        return this.functionBytes;
    }

    /**
     * Renders an object named after {@code getName()} that holds the raw function bytes under the field
     * name, followed by the multi-value mode.
     */
    @Override
    public void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(getName());
        builder.rawField(fieldName, functionBytes);
        builder.field(DecayFunctionParser.MULTI_VALUE_MODE.getPreferredName(), multiValueMode.name());
        builder.endObject();
    }

    /**
     * Sets the mode used to reduce multiple values of a document to a single distance.
     *
     * @param multiValueMode the mode, must not be {@code null}
     * @return this builder
     * @throws IllegalArgumentException if {@code multiValueMode} is {@code null}
     */
    @SuppressWarnings("unchecked")
    public DFB setMultiValueMode(MultiValueMode multiValueMode) {
        if (multiValueMode == null) {
            throw new IllegalArgumentException("decay function: multi_value_mode must not be null");
        }
        this.multiValueMode = multiValueMode;
        return (DFB) this;
    }

    /** Returns the mode used to reduce multiple values of a document to a single distance. */
    public MultiValueMode getMultiValueMode() {
        return this.multiValueMode;
    }

    @Override
    protected boolean doEquals(DFB functionBuilder) {
        return Objects.equals(this.fieldName, functionBuilder.getFieldName())
                && Objects.equals(this.functionBytes, functionBuilder.getFunctionBytes())
                && Objects.equals(this.multiValueMode, functionBuilder.getMultiValueMode());
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(this.fieldName, this.functionBytes, this.multiValueMode);
    }

    /**
     * Parses the raw function bytes against the mapping of the target field and builds the score function.
     */
    @Override
    protected ScoreFunction doToFunction(QueryShardContext context) throws IOException {
        AbstractDistanceScoreFunction scoreFunction;
        // EMPTY is safe because parseVariable doesn't use namedObject
        try (XContentParser parser = XContentFactory.xContent(functionBytes)
                .createParser(NamedXContentRegistry.EMPTY, functionBytes)) {
            scoreFunction = parseVariable(fieldName, parser, context, multiValueMode);
        }
        return scoreFunction;
    }

    /**
     * Override this function if you want to produce your own scorer.
     */
    protected abstract DecayFunction getDecayFunction();

    /**
     * Looks up the mapped type of {@code fieldName} and dispatches to the date, geo or numeric parser.
     *
     * @throws ParsingException if the field is not mapped, or is mapped to a type that is neither a date,
     *                          a geo point nor a number
     */
    private AbstractDistanceScoreFunction parseVariable(String fieldName, XContentParser parser, QueryShardContext context,
            MultiValueMode mode) throws IOException {
        // the field must exist, else we cannot read the value for the doc later
        MappedFieldType fieldType = context.fieldMapper(fieldName);
        if (fieldType == null) {
            throw new ParsingException(parser.getTokenLocation(), "unknown field [{}]", fieldName);
        }

        // dates and time and geo need special handling
        parser.nextToken();
        if (fieldType instanceof DateFieldMapper.DateFieldType) {
            return parseDateVariable(parser, context, fieldType, mode);
        } else if (fieldType instanceof GeoPointFieldType) {
            return parseGeoVariable(parser, context, fieldType, mode);
        } else if (fieldType instanceof NumberFieldMapper.NumberFieldType) {
            return parseNumberVariable(parser, context, fieldType, mode);
        } else {
            throw new ParsingException(parser.getTokenLocation(), "field [{}] is of type [{}], but only numeric types are supported.",
                    fieldName, fieldType);
        }
    }

    /**
     * Parses the function parameters for a numeric field. Both {@code scale} and {@code origin} are
     * mandatory; {@code decay} defaults to {@link #DEFAULT_DECAY} and {@code offset} to 0.
     *
     * @throws ElasticsearchParseException on an unknown parameter or when scale or origin is missing
     */
    private AbstractDistanceScoreFunction parseNumberVariable(XContentParser parser, QueryShardContext context,
            MappedFieldType fieldType, MultiValueMode mode) throws IOException {
        XContentParser.Token token;
        String parameterName = null;
        double scale = 0;
        double origin = 0;
        double decay = DEFAULT_DECAY;
        double offset = 0.0d;
        boolean scaleFound = false;
        boolean refFound = false;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                parameterName = parser.currentName();
            } else if (DecayFunctionBuilder.SCALE.equals(parameterName)) {
                scale = parser.doubleValue();
                scaleFound = true;
            } else if (DecayFunctionBuilder.DECAY.equals(parameterName)) {
                decay = parser.doubleValue();
            } else if (DecayFunctionBuilder.ORIGIN.equals(parameterName)) {
                origin = parser.doubleValue();
                refFound = true;
            } else if (DecayFunctionBuilder.OFFSET.equals(parameterName)) {
                offset = parser.doubleValue();
            } else {
                throw new ElasticsearchParseException("parameter [{}] not supported!", parameterName);
            }
        }
        if (!scaleFound || !refFound) {
            throw new ElasticsearchParseException("both [{}] and [{}] must be set for numeric fields.", DecayFunctionBuilder.SCALE,
                    DecayFunctionBuilder.ORIGIN);
        }
        IndexNumericFieldData numericFieldData = context.getForField(fieldType);
        return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode);
    }

    /**
     * Parses the function parameters for a geo point field. Both {@code origin} and {@code scale} are
     * mandatory; {@code decay} defaults to {@link #DEFAULT_DECAY} and {@code offset} to {@code "0km"}.
     * Scale and offset are parsed as distances with {@link DistanceUnit#DEFAULT}.
     *
     * @throws ElasticsearchParseException on an unknown parameter or when origin or scale is missing
     */
    private AbstractDistanceScoreFunction parseGeoVariable(XContentParser parser, QueryShardContext context,
            MappedFieldType fieldType, MultiValueMode mode) throws IOException {
        XContentParser.Token token;
        String parameterName = null;
        // Starts out null so that a missing origin is actually detected by the check below; starting from
        // a default-constructed GeoPoint would make that check unreachable and silently score against it.
        GeoPoint origin = null;
        String scaleString = null;
        String offsetString = "0km";
        double decay = DEFAULT_DECAY;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                parameterName = parser.currentName();
            } else if (DecayFunctionBuilder.SCALE.equals(parameterName)) {
                scaleString = parser.text();
            } else if (DecayFunctionBuilder.ORIGIN.equals(parameterName)) {
                origin = GeoUtils.parseGeoPoint(parser);
            } else if (DecayFunctionBuilder.DECAY.equals(parameterName)) {
                decay = parser.doubleValue();
            } else if (DecayFunctionBuilder.OFFSET.equals(parameterName)) {
                offsetString = parser.text();
            } else {
                throw new ElasticsearchParseException("parameter [{}] not supported!", parameterName);
            }
        }
        if (origin == null || scaleString == null) {
            throw new ElasticsearchParseException("[{}] and [{}] must be set for geo fields.", DecayFunctionBuilder.ORIGIN,
                    DecayFunctionBuilder.SCALE);
        }
        double scale = DistanceUnit.DEFAULT.parse(scaleString, DistanceUnit.DEFAULT);
        double offset = DistanceUnit.DEFAULT.parse(offsetString, DistanceUnit.DEFAULT);
        IndexGeoPointFieldData indexFieldData = context.getForField(fieldType);
        return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode);
    }

    /**
     * Parses the function parameters for a date field. {@code scale} is mandatory; a missing
     * {@code origin} falls back to {@code context.nowInMillis()}, {@code decay} defaults to
     * {@link #DEFAULT_DECAY} and {@code offset} to {@code "0d"}. Scale and offset are parsed as time
     * values and converted to milliseconds.
     *
     * @throws ElasticsearchParseException on an unknown parameter or when scale is missing
     */
    private AbstractDistanceScoreFunction parseDateVariable(XContentParser parser, QueryShardContext context,
            MappedFieldType dateFieldType, MultiValueMode mode) throws IOException {
        XContentParser.Token token;
        String parameterName = null;
        String scaleString = null;
        String originString = null;
        String offsetString = "0d";
        double decay = DEFAULT_DECAY;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                parameterName = parser.currentName();
            } else if (DecayFunctionBuilder.SCALE.equals(parameterName)) {
                scaleString = parser.text();
            } else if (DecayFunctionBuilder.ORIGIN.equals(parameterName)) {
                originString = parser.text();
            } else if (DecayFunctionBuilder.DECAY.equals(parameterName)) {
                decay = parser.doubleValue();
            } else if (DecayFunctionBuilder.OFFSET.equals(parameterName)) {
                offsetString = parser.text();
            } else {
                throw new ElasticsearchParseException("parameter [{}] not supported!", parameterName);
            }
        }
        long origin;
        if (originString == null) {
            origin = context.nowInMillis();
        } else {
            origin = ((DateFieldMapper.DateFieldType) dateFieldType).parseToMilliseconds(originString, false, null, null, context);
        }

        if (scaleString == null) {
            throw new ElasticsearchParseException("[{}] must be set for date fields.", DecayFunctionBuilder.SCALE);
        }
        TimeValue val = TimeValue.parseTimeValue(scaleString, TimeValue.timeValueHours(24),
                DecayFunctionParser.class.getSimpleName() + ".scale");
        double scale = val.getMillis();
        val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24),
                DecayFunctionParser.class.getSimpleName() + ".offset");
        double offset = val.getMillis();
        IndexNumericFieldData numericFieldData = context.getForField(dateFieldType);
        return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode);
    }

    /**
     * Score function whose distance is the arc distance between a document's geo point(s) and a fixed
     * origin, reduced by the offset and floored at 0.
     */
    static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction {

        private final GeoPoint origin;
        private final IndexGeoPointFieldData fieldData;

        private static final GeoDistance distFunction = GeoDistance.ARC;

        GeoFieldDataScoreFunction(GeoPoint origin, double scale, double decay, double offset, DecayFunction func,
                                  IndexGeoPointFieldData fieldData, MultiValueMode mode) {
            super(scale, decay, offset, func, mode);
            this.origin = origin;
            this.fieldData = fieldData;
        }

        @Override
        public boolean needsScores() {
            return false;
        }

        @Override
        protected NumericDoubleValues distance(LeafReaderContext context) {
            final MultiGeoPointValues geoPointValues = fieldData.load(context).getGeoPointValues();
            // Expose the per-value distances and let the multi-value mode reduce them to one value per doc.
            return mode.select(new MultiValueMode.UnsortedNumericDoubleValues() {
                @Override
                public int docValueCount() {
                    return geoPointValues.docValueCount();
                }

                @Override
                public boolean advanceExact(int docId) throws IOException {
                    return geoPointValues.advanceExact(docId);
                }

                @Override
                public double nextValue() throws IOException {
                    GeoPoint other = geoPointValues.nextValue();
                    return Math.max(0.0d,
                            distFunction.calculate(origin.lat(), origin.lon(), other.lat(), other.lon(), DistanceUnit.METERS) - offset);
                }
            }, 0.0);
        }

        @Override
        protected String getDistanceString(LeafReaderContext ctx, int docId) throws IOException {
            StringBuilder values = new StringBuilder(mode.name());
            values.append(" of: [");
            final MultiGeoPointValues geoPointValues = fieldData.load(ctx).getGeoPointValues();
            if (geoPointValues.advanceExact(docId)) {
                final int num = geoPointValues.docValueCount();
                for (int i = 0; i < num; i++) {
                    GeoPoint value = geoPointValues.nextValue();
                    values.append("Math.max(arcDistance(");
                    values.append(value).append("(=doc value),");
                    values.append(origin).append("(=origin)) - ").append(offset).append("(=offset), 0)");
                    if (i != num - 1) {
                        values.append(", ");
                    }
                }
            } else {
                values.append("0.0");
            }
            values.append("]");
            return values.toString();
        }

        @Override
        protected String getFieldName() {
            return fieldData.getFieldName();
        }

        @Override
        protected boolean doEquals(ScoreFunction other) {
            GeoFieldDataScoreFunction geoFieldDataScoreFunction = (GeoFieldDataScoreFunction) other;
            return super.doEquals(other) && Objects.equals(this.origin, geoFieldDataScoreFunction.origin);
        }

        @Override
        protected int doHashCode() {
            return Objects.hash(super.doHashCode(), origin);
        }
    }

    /**
     * Score function whose distance is the absolute difference between a document's numeric value(s) and
     * a fixed origin, reduced by the offset and floored at 0. Also used for date fields, with all values
     * expressed in milliseconds.
     */
    static class NumericFieldDataScoreFunction extends AbstractDistanceScoreFunction {

        private final IndexNumericFieldData fieldData;
        private final double origin;

        NumericFieldDataScoreFunction(double origin, double scale, double decay, double offset, DecayFunction func,
                                      IndexNumericFieldData fieldData, MultiValueMode mode) {
            super(scale, decay, offset, func, mode);
            this.fieldData = fieldData;
            this.origin = origin;
        }

        @Override
        public boolean needsScores() {
            return false;
        }

        @Override
        protected NumericDoubleValues distance(LeafReaderContext context) {
            final SortedNumericDoubleValues doubleValues = fieldData.load(context).getDoubleValues();
            // Expose the per-value distances and let the multi-value mode reduce them to one value per doc.
            return mode.select(new MultiValueMode.UnsortedNumericDoubleValues() {
                @Override
                public int docValueCount() {
                    return doubleValues.docValueCount();
                }

                @Override
                public boolean advanceExact(int doc) throws IOException {
                    return doubleValues.advanceExact(doc);
                }

                @Override
                public double nextValue() throws IOException {
                    return Math.max(0.0d, Math.abs(doubleValues.nextValue() - origin) - offset);
                }
            }, 0.0);
        }

        @Override
        protected String getDistanceString(LeafReaderContext ctx, int docId) throws IOException {
            StringBuilder values = new StringBuilder(mode.name());
            values.append("[");
            final SortedNumericDoubleValues doubleValues = fieldData.load(ctx).getDoubleValues();
            if (doubleValues.advanceExact(docId)) {
                final int num = doubleValues.docValueCount();
                for (int i = 0; i < num; i++) {
                    double value = doubleValues.nextValue();
                    values.append("Math.max(Math.abs(");
                    values.append(value).append("(=doc value) - ");
                    values.append(origin).append("(=origin))) - ");
                    values.append(offset).append("(=offset), 0)");
                    if (i != num - 1) {
                        values.append(", ");
                    }
                }
            } else {
                values.append("0.0");
            }
            values.append("]");
            return values.toString();
        }

        @Override
        protected String getFieldName() {
            return fieldData.getFieldName();
        }

        @Override
        protected boolean doEquals(ScoreFunction other) {
            NumericFieldDataScoreFunction numericFieldDataScoreFunction = (NumericFieldDataScoreFunction) other;
            if (super.doEquals(other) == false) {
                return false;
            }
            return Objects.equals(this.origin, numericFieldDataScoreFunction.origin);
        }

        // Mirrors doEquals (and GeoFieldDataScoreFunction): origin takes part in equality, so it takes
        // part in the hash as well.
        @Override
        protected int doHashCode() {
            return Objects.hash(super.doHashCode(), origin);
        }
    }

    /**
     * This is the base class for scoring a single field.
     */
    public abstract static class AbstractDistanceScoreFunction extends ScoreFunction {

        private final double scale;
        protected final double offset;
        private final DecayFunction func;
        protected final MultiValueMode mode;

        /**
         * @param userSuppliedScale the scale as given by the user, must be &gt; 0; it is run through
         *                          {@code func.processScale} before being stored
         * @param decay             the decay, must be strictly between 0 and 1
         * @param offset            the offset, must be &gt;= 0
         * @param func              the decay function to evaluate
         * @param mode              how multiple values of a document are reduced to one distance
         * @throws IllegalArgumentException if scale, decay or offset is out of range
         */
        public AbstractDistanceScoreFunction(double userSuppliedScale, double decay, double offset, DecayFunction func,
                                             MultiValueMode mode) {
            super(CombineFunction.MULTIPLY);
            this.mode = mode;
            if (userSuppliedScale <= 0.0) {
                throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : scale must be > 0.0.");
            }
            if (decay <= 0.0 || decay >= 1.0) {
                throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : decay must be in the range [0..1].");
            }
            this.scale = func.processScale(userSuppliedScale, decay);
            this.func = func;
            if (offset < 0.0d) {
                // An offset of exactly 0 is valid, so the message states ">=".
                throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : offset must be >= 0.0");
            }
            this.offset = offset;
        }

        /**
         * This function computes the distance from a defined origin. Since
         * the value of the document is read from the index, it cannot be
         * guaranteed that the value actually exists. If it does not, we assume
         * the user handles this case in the query and return 0.
         */
        protected abstract NumericDoubleValues distance(LeafReaderContext context);

        @Override
        public final LeafScoreFunction getLeafScoreFunction(final LeafReaderContext ctx) {
            final NumericDoubleValues distance = distance(ctx);
            return new LeafScoreFunction() {

                @Override
                public double score(int docId, float subQueryScore) throws IOException {
                    if (distance.advanceExact(docId)) {
                        return func.evaluate(distance.doubleValue(), scale);
                    } else {
                        return 0;
                    }
                }

                @Override
                public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
                    if (distance.advanceExact(docId) == false) {
                        return Explanation.noMatch("No value for the distance");
                    }
                    return Explanation.match(
                            CombineFunction.toFloat(score(docId, subQueryScore.getValue())),
                            "Function for field " + getFieldName() + ":",
                            func.explainFunction(getDistanceString(ctx, docId), distance.doubleValue(), scale));
                }
            };
        }

        /** Returns a textual description of how the distance of {@code docId} is computed, used in explanations. */
        protected abstract String getDistanceString(LeafReaderContext ctx, int docId) throws IOException;

        /** Returns the name of the field the distance is computed on. */
        protected abstract String getFieldName();

        @Override
        protected boolean doEquals(ScoreFunction other) {
            AbstractDistanceScoreFunction distanceScoreFunction = (AbstractDistanceScoreFunction) other;
            return Objects.equals(this.scale, distanceScoreFunction.scale)
                    && Objects.equals(this.offset, distanceScoreFunction.offset)
                    && Objects.equals(this.mode, distanceScoreFunction.mode)
                    && Objects.equals(this.func, distanceScoreFunction.func)
                    && Objects.equals(this.getFieldName(), distanceScoreFunction.getFieldName());
        }

        @Override
        protected int doHashCode() {
            return Objects.hash(scale, offset, mode, func, getFieldName());
        }
    }
}