/*
* Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership. Crate licenses
* this file to you under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* However, if you have executed another commercial license agreement
* with Crate these terms will supersede the license and you may use the
* software solely pursuant to the terms of the relevant commercial agreement.
*/
package io.crate.operation.scalar.geo;
import io.crate.analyze.symbol.Function;
import io.crate.analyze.symbol.Literal;
import io.crate.analyze.symbol.Symbol;
import io.crate.analyze.symbol.format.SymbolFormatter;
import io.crate.metadata.*;
import io.crate.data.Input;
import io.crate.operation.scalar.ScalarFunctionModule;
import io.crate.types.ArrayType;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import org.elasticsearch.common.geo.GeoUtils;
import java.util.Arrays;
import java.util.List;
/**
 * Scalar function {@code distance(a, b)} that returns the arc distance between two geo points,
 * as computed by Elasticsearch's {@link GeoUtils#arcDistance} (presumably meters — confirm
 * against the Elasticsearch version in use).
 *
 * <p>During normalization, literal arguments are converted to {@code geo_point}. Reference
 * arguments must already be of type {@code geo_point}. A reference is always moved to the
 * first argument position.
 */
public class DistanceFunction extends Scalar<Double, Object> {

    public static final String NAME = "distance";

    /** Argument types accepted by the function resolver for each of the two arguments. */
    private static final Signature.ArgMatcher ALLOWED_TYPE = Signature.ArgMatcher.of(
        DataTypes.STRING, DataTypes.GEO_POINT, new ArrayType(DataTypes.DOUBLE));

    /** Function info of the (geo_point, geo_point) variant that normalization rewrites to. */
    private static final FunctionInfo GEO_POINT_INFO =
        genInfo(Arrays.asList(DataTypes.GEO_POINT, DataTypes.GEO_POINT));

    /** Position of the longitude inside a geo point value. */
    private static final int LON_INDEX = 0;

    /** Position of the latitude inside a geo point value. */
    private static final int LAT_INDEX = 1;

    private final FunctionInfo info;

    /**
     * Registers the {@code distance} function with a resolver that accepts two arguments of any
     * of the {@link #ALLOWED_TYPE} types.
     *
     * @param module the module to register the function with
     */
    public static void register(ScalarFunctionModule module) {
        module.register(NAME, new BaseFunctionResolver(Signature.of(ALLOWED_TYPE, ALLOWED_TYPE)) {
            @Override
            public FunctionImplementation getForTypes(List<DataType> dataTypes) throws IllegalArgumentException {
                return new DistanceFunction(genInfo(dataTypes));
            }
        });
    }

    /** Builds the function info for the given argument types; the return type is always double. */
    private static FunctionInfo genInfo(List<DataType> argumentTypes) {
        return new FunctionInfo(new FunctionIdent(NAME, argumentTypes), DataTypes.DOUBLE);
    }

    DistanceFunction(FunctionInfo info) {
        this.info = info;
    }

    @Override
    public FunctionInfo info() {
        return info;
    }

    @Override
    public Double evaluate(Input[] args) {
        assert args.length == 2 : "number of args must be 2";
        return evaluate(args[0], args[1]);
    }

    /**
     * Computes the arc distance between two geo point inputs.
     *
     * <p>Each value must be a {@code List} or an array whose element 0 is the longitude and
     * element 1 is the latitude.
     *
     * @param arg1 source point
     * @param arg2 target point
     * @return the distance, or {@code null} if either value is {@code null}. {@code arg2} is
     *         not read when {@code arg1}'s value is {@code null}.
     */
    public Double evaluate(Input arg1, Input arg2) {
        Object source = arg1.value();
        if (source == null) {
            return null;
        }
        Object target = arg2.value();
        if (target == null) {
            return null;
        }
        return GeoUtils.arcDistance(
            coordinate(source, LAT_INDEX), coordinate(source, LON_INDEX),
            coordinate(target, LAT_INDEX), coordinate(target, LON_INDEX));
    }

    /**
     * Reads one coordinate from a geo point value.
     *
     * <p>Lists must also be handled, because e.g. ESSearchTask returns geo points as a list.
     * Elements are read via {@link Number#doubleValue()} rather than cast to {@code Double}.
     * A list or array holding other numeric types (such as {@code Integer}) therefore no
     * longer fails with a {@code ClassCastException}.
     *
     * @param point a {@code List}, {@code double[]}, or object array of numbers
     * @param index {@link #LON_INDEX} or {@link #LAT_INDEX}
     */
    private static double coordinate(Object point, int index) {
        if (point instanceof List) {
            return ((Number) ((List<?>) point).get(index)).doubleValue();
        }
        if (point instanceof double[]) {
            return ((double[]) point)[index];
        }
        // Double[] is an Object[] (array covariance), so the original Double[] input still works.
        return ((Number) ((Object[]) point)[index]).doubleValue();
    }

    /**
     * Normalizes a {@code distance} call.
     *
     * <ul>
     *   <li>Literal arguments that are not {@code geo_point} are converted to {@code geo_point}.</li>
     *   <li>Non-literal arguments must already be {@code geo_point}.</li>
     *   <li>If both arguments are literals, the call is evaluated into a literal.</li>
     *   <li>Otherwise, a literal first argument is swapped behind the reference, so the reference
     *       is always the first argument.</li>
     *   <li>The original symbol is returned if nothing was converted or swapped.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if a non-literal argument is not of type geo_point
     */
    @Override
    public Symbol normalizeSymbol(Function symbol, TransactionContext transactionContext) {
        Symbol arg1 = symbol.arguments().get(0);
        Symbol arg2 = symbol.arguments().get(1);
        DataType arg1Type = arg1.valueType();
        DataType arg2Type = arg2.valueType();

        boolean arg1IsLiteral = arg1.symbolType().isValueSymbol();
        boolean arg2IsLiteral = arg2.symbolType().isValueSymbol();
        boolean literalConverted = false;

        if (arg1IsLiteral) {
            if (!arg1Type.equals(DataTypes.GEO_POINT)) {
                literalConverted = true;
                arg1 = Literal.convert(arg1, DataTypes.GEO_POINT);
            }
        } else {
            validateType(arg1, arg1Type);
        }

        if (arg2IsLiteral) {
            if (!arg2Type.equals(DataTypes.GEO_POINT)) {
                literalConverted = true;
                arg2 = Literal.convert(arg2, DataTypes.GEO_POINT);
            }
        } else {
            validateType(arg2, arg2Type);
        }

        if (arg1IsLiteral && arg2IsLiteral) {
            return Literal.of(evaluate((Input) arg1, (Input) arg2));
        }
        if (arg1IsLiteral) {
            // ensure reference is the first argument.
            return new Function(GEO_POINT_INFO, Arrays.asList(arg2, arg1));
        }
        if (literalConverted) {
            return new Function(GEO_POINT_INFO, Arrays.asList(arg1, arg2));
        }
        return symbol;
    }

    /** Rejects a non-literal argument whose type is not geo_point. */
    private static void validateType(Symbol symbol, DataType dataType) {
        if (!dataType.equals(DataTypes.GEO_POINT)) {
            throw new IllegalArgumentException(SymbolFormatter.format(
                "Cannot convert %s to a geo point", symbol));
        }
    }
}