package io.crate.operation.scalar.arithmetic; import com.google.common.collect.Sets; import io.crate.analyze.symbol.Function; import io.crate.analyze.symbol.Symbol; import io.crate.metadata.*; import io.crate.operation.scalar.ScalarFunctionModule; import io.crate.types.DataType; import io.crate.types.DataTypes; import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.BinaryOperator; public class ArithmeticFunctions { private final static Set<DataType> NUMERIC_WITH_DECIMAL = Sets.newHashSet(DataTypes.FLOAT, DataTypes.DOUBLE); public static class Names { public static final String ADD = "add"; public static final String SUBTRACT = "subtract"; public static final String MULTIPLY = "multiply"; public static final String DIVIDE = "divide"; public static final String POWER = "power"; public static final String MODULUS = "modulus"; } public static void register(ScalarFunctionModule module) { module.register(Names.ADD, new ArithmeticFunctionResolver( Names.ADD, "+", FunctionInfo.DETERMINISTIC_AND_COMPARISON_REPLACEMENT, (arg0, arg1) -> arg0 + arg1, (arg0, arg1) -> arg0 + arg1, (arg0, arg1) -> arg0 + arg1 )); module.register(Names.SUBTRACT, new ArithmeticFunctionResolver( Names.SUBTRACT, "-", FunctionInfo.DETERMINISTIC_ONLY, (arg0, arg1) -> arg0 - arg1, (arg0, arg1) -> arg0 - arg1, (arg0, arg1) -> arg0 - arg1 )); module.register(Names.MULTIPLY, new ArithmeticFunctionResolver( Names.MULTIPLY, "*", FunctionInfo.DETERMINISTIC_ONLY, (arg0, arg1) -> arg0 * arg1, (arg0, arg1) -> arg0 * arg1, (arg0, arg1) -> arg0 * arg1 )); module.register(Names.DIVIDE, new ArithmeticFunctionResolver( Names.DIVIDE, "/", FunctionInfo.DETERMINISTIC_ONLY, (arg0, arg1) -> arg0 / arg1, (arg0, arg1) -> arg0 / arg1, (arg0, arg1) -> arg0 / arg1 )); module.register(Names.MODULUS, new ArithmeticFunctionResolver( Names.MODULUS, "%", FunctionInfo.DETERMINISTIC_ONLY, (arg0, arg1) -> arg0 % arg1, (arg0, arg1) -> arg0 % arg1, (arg0, arg1) -> arg0 % arg1 )); module.register(Names.POWER, new DoubleFunctionResolver( Names.POWER, (arg0, arg1) -> Math.pow(arg0, arg1) )); } final static class DoubleFunctionResolver extends BaseFunctionResolver { private static final Signature.ArgMatcher ARITHMETIC_TYPE = Signature.ArgMatcher.of( DataTypes.NUMERIC_PRIMITIVE_TYPES::contains, DataTypes.TIMESTAMP::equals); private final String name; private final BinaryOperator<Double> doubleFunction; DoubleFunctionResolver(String name, BinaryOperator<Double> doubleFunction) { super(Signature.of(ARITHMETIC_TYPE, ARITHMETIC_TYPE)); this.name = name; this.doubleFunction = doubleFunction; } @Override public FunctionImplementation getForTypes(List<DataType> dataTypes) throws IllegalArgumentException { return new BinaryScalar<>(doubleFunction, name, DataTypes.DOUBLE, FunctionInfo.DETERMINISTIC_ONLY); } } final static class ArithmeticFunctionResolver extends BaseFunctionResolver { private static final Signature.ArgMatcher ARITHMETIC_TYPE = Signature.ArgMatcher.of( DataTypes.NUMERIC_PRIMITIVE_TYPES::contains, DataTypes.TIMESTAMP::equals); private final String name; private final String operator; private final Set<FunctionInfo.Feature> features; private final BinaryOperator<Double> doubleFunction; private final BinaryOperator<Long> longFunction; private final BinaryOperator<Float> floatFunction; ArithmeticFunctionResolver(String name, String operator, Set<FunctionInfo.Feature> features, BinaryOperator<Double> doubleFunction, BinaryOperator<Long> longFunction, BinaryOperator<Float> floatFunction) { super(Signature.of(ARITHMETIC_TYPE, ARITHMETIC_TYPE)); this.name = name; this.operator = operator; this.doubleFunction = doubleFunction; this.longFunction = longFunction; this.floatFunction = floatFunction; this.features = features; } @Override public FunctionImplementation getForTypes(List<DataType> dataTypes) throws IllegalArgumentException { BinaryScalar<?> scalar; if (containsTypesWithDecimal(dataTypes)) { if (containsDouble(dataTypes)) { scalar = new BinaryScalar<>(doubleFunction, name, DataTypes.DOUBLE, features); } else { scalar = new BinaryScalar<>(floatFunction, name, DataTypes.FLOAT, features); } } else { scalar = new BinaryScalar<>(longFunction, name, DataTypes.LONG, features); } return Scalar.withOperator(scalar, operator); } } public static Function of(String name, Symbol first, Symbol second, Set<FunctionInfo.Feature> features) { List<DataType> argumentTypes = Arrays.asList(first.valueType(), second.valueType()); if (containsTypesWithDecimal(argumentTypes)) { return new Function( genDoubleInfo(name, argumentTypes, features), Arrays.asList(first, second)); } return new Function( genLongInfo(name, argumentTypes, features), Arrays.asList(first, second)); } static boolean containsTypesWithDecimal(List<DataType> dataTypes) { for (DataType dataType : dataTypes) { if (NUMERIC_WITH_DECIMAL.contains(dataType)) { return true; } } return false; } static boolean containsDouble(List<DataType> dataTypes) { for (DataType dataType : dataTypes) { if (dataType.equals(DataTypes.DOUBLE)) { return true; } } return false; } static FunctionInfo genDoubleInfo(String functionName, List<DataType> dataTypes, Set<FunctionInfo.Feature> features) { return new FunctionInfo(new FunctionIdent(functionName, dataTypes), DataTypes.DOUBLE, FunctionInfo.Type.SCALAR, features); } static FunctionInfo genLongInfo(String functionName, List<DataType> dataTypes, Set<FunctionInfo.Feature> features) { return new FunctionInfo(new FunctionIdent(functionName, dataTypes), DataTypes.LONG, FunctionInfo.Type.SCALAR, features); } }