/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.metadata; import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.testng.annotations.Test; import static com.facebook.presto.metadata.FunctionKind.SCALAR; import static com.facebook.presto.metadata.Signature.comparableWithVariadicBound; import static com.facebook.presto.metadata.TestPolymorphicScalarFunction.TestMethods.VARCHAR_TO_BIGINT_RETURN_VALUE; import static com.facebook.presto.metadata.TestPolymorphicScalarFunction.TestMethods.VARCHAR_TO_VARCHAR_RETURN_VALUE; import static com.facebook.presto.spi.function.OperatorType.ADD; import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static java.lang.Math.toIntExact; import static org.testng.Assert.assertEquals; public class TestPolymorphicScalarFunction { private static final TypeRegistry TYPE_REGISTRY = new TypeRegistry(); private static final FunctionRegistry REGISTRY = new FunctionRegistry(TYPE_REGISTRY, new BlockEncodingManager(TYPE_REGISTRY), new FeaturesConfig()); private static final Signature SIGNATURE = Signature.builder() .name("foo") .kind(SCALAR) .returnType(parseTypeSignature(StandardTypes.BIGINT)) .argumentTypes(parseTypeSignature("varchar(x)", ImmutableSet.of("x"))) .build(); private static final long INPUT_VARCHAR_LENGTH = 10; private static final String INPUT_VARCHAR_SIGNATURE = "varchar(" + INPUT_VARCHAR_LENGTH + ")"; private static final TypeSignature INPUT_VARCHAR_TYPE = parseTypeSignature(INPUT_VARCHAR_SIGNATURE); private static final Slice INPUT_SLICE = Slices.allocate(toIntExact(INPUT_VARCHAR_LENGTH)); private static final BoundVariables BOUND_VARIABLES = new BoundVariables( ImmutableMap.of("V", TYPE_REGISTRY.getType(INPUT_VARCHAR_TYPE)), ImmutableMap.of("x", INPUT_VARCHAR_LENGTH) ); @Test public void testSelectsMethodBasedOnArgumentTypes() throws Throwable { SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b.methods("bigintToBigintReturnExtraParameter")) .implementation(b -> b .methods("varcharToBigintReturnExtraParameter") .withExtraParameters(context -> ImmutableList.of(context.getLiteral("x"))) ) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); assertEquals(functionImplementation.getMethodHandle().invoke(INPUT_SLICE), INPUT_VARCHAR_LENGTH); } @Test public void testSelectsMethodBasedOnReturnType() throws Throwable { SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b.methods("varcharToVarcharCreateSliceWithExtraParameterLength")) .implementation(b -> b .methods("varcharToBigintReturnExtraParameter") .withExtraParameters(context -> ImmutableList.of(42)) ) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); assertEquals(functionImplementation.getMethodHandle().invoke(INPUT_SLICE), VARCHAR_TO_BIGINT_RETURN_VALUE); } @Test public void testSelectsFirstMethodBasedOnPredicate() throws Throwable { SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b .methods("varcharToBigint") .withPredicate(context -> true) ) .implementation(b -> b.methods("varcharToBigintReturnExtraParameter")) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); assertEquals(functionImplementation.getMethodHandle().invoke(INPUT_SLICE), VARCHAR_TO_BIGINT_RETURN_VALUE); } @Test public void testSelectsSecondMethodBasedOnPredicate() throws Throwable { SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b .methods("varcharToBigintReturnExtraParameter") .withPredicate(context -> false) ) .implementation(b -> b.methods("varcharToBigint")) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); assertEquals(functionImplementation.getMethodHandle().invoke(INPUT_SLICE), VARCHAR_TO_BIGINT_RETURN_VALUE); } @Test public void testSameLiteralInArgumentsAndReturnValue() throws Throwable { Signature signature = Signature.builder() .name("foo") .kind(SCALAR) .returnType(parseTypeSignature("varchar(x)", ImmutableSet.of("x"))) .argumentTypes(parseTypeSignature("varchar(x)", ImmutableSet.of("x"))) .build(); SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(signature) .implementation(b -> b.methods("varcharToVarchar")) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); Slice slice = (Slice) functionImplementation.getMethodHandle().invoke(INPUT_SLICE); assertEquals(slice, VARCHAR_TO_VARCHAR_RETURN_VALUE); } @Test public void testTypeParameters() throws Throwable { Signature signature = Signature.builder() .name("foo") .kind(SCALAR) .typeVariableConstraints(comparableWithVariadicBound("V", VARCHAR)) .returnType(parseTypeSignature("V")) .argumentTypes(parseTypeSignature("V")) .build(); SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(signature) .implementation(b -> b.methods("varcharToVarchar")) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); Slice slice = (Slice) functionImplementation.getMethodHandle().invoke(INPUT_SLICE); assertEquals(slice, VARCHAR_TO_VARCHAR_RETURN_VALUE); } @Test public void testSetsHiddenToTrueForOperators() { Signature signature = Signature.builder() .operatorType(ADD) .kind(SCALAR) .returnType(parseTypeSignature("varchar(x)", ImmutableSet.of("x"))) .argumentTypes(parseTypeSignature("varchar(x)", ImmutableSet.of("x"))) .build(); SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(signature) .implementation(b -> b.methods("varcharToVarchar")) .build(); ScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); } @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = "method foo was not found in class com.facebook.presto.metadata.TestPolymorphicScalarFunction\\$TestMethods") public void testFailIfNotAllMethodsPresent() { SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b.methods("bigintToBigintReturnExtraParameter")) .implementation(b -> b.methods("foo")) .build(); } @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = "methods must be selected first") public void testFailNoMethodsAreSelectedWhenExtraParametersFunctionIsSet() { SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b .withExtraParameters(context -> ImmutableList.of(42)) ) .build(); } @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = "two matching methods \\(varcharToBigintReturnFirstExtraParameter and varcharToBigintReturnExtraParameter\\) for parameter types \\[varchar\\(10\\)\\]") public void testFailIfTwoMethodsWithSameArguments() { SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b.methods("varcharToBigintReturnFirstExtraParameter")) .implementation(b -> b.methods("varcharToBigintReturnExtraParameter")) .build(); function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); } @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = "two matching methods \\(varcharToBigintReturnFirstExtraParameter and varcharToBigintReturnExtraParameter\\) for parameter types \\[varchar\\(10\\)\\]") public void testFailIfTwoMethodsWithPredicatesWithSameArguments() { SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class) .signature(SIGNATURE) .implementation(b -> b .methods("varcharToBigintReturnFirstExtraParameter") .withPredicate(context -> true) ) .implementation(b -> b .methods("varcharToBigintReturnExtraParameter") .withPredicate(context -> true) ) .build(); function.specialize(BOUND_VARIABLES, 1, TYPE_REGISTRY, REGISTRY); } public static class TestMethods { static final Slice VARCHAR_TO_VARCHAR_RETURN_VALUE = Slices.utf8Slice("hello world"); static final long VARCHAR_TO_BIGINT_RETURN_VALUE = 42L; public static Slice varcharToVarchar(Slice varchar) { return VARCHAR_TO_VARCHAR_RETURN_VALUE; } public static long varcharToBigint(Slice varchar) { return VARCHAR_TO_BIGINT_RETURN_VALUE; } public static long varcharToBigintReturnExtraParameter(Slice varchar, long extraParameter) { return extraParameter; } public static long bigintToBigintReturnExtraParameter(long bigint, int extraParameter) { return bigint; } public static long varcharToBigintReturnFirstExtraParameter(Slice varchar, long extraParameter1, int extraParameter2) { return extraParameter1; } public static Slice varcharToVarcharCreateSliceWithExtraParameterLength(Slice string, int extraParameter) { return Slices.allocate(extraParameter); } } }