/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.metadata;

import com.facebook.presto.block.BlockEncodingManager;
import com.facebook.presto.operator.scalar.CustomFunctions;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
import com.facebook.presto.spi.block.BlockEncodingSerde;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.type.TypeRegistry;
import com.google.common.base.Functions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandles;
import java.util.Collections;
import java.util.List;

import static com.facebook.presto.metadata.FunctionKind.SCALAR;
import static com.facebook.presto.metadata.FunctionRegistry.getMagicLiteralFunctionSignature;
import static com.facebook.presto.metadata.FunctionRegistry.mangleOperatorName;
import static com.facebook.presto.metadata.FunctionRegistry.unmangleOperator;
import static com.facebook.presto.metadata.Signature.typeVariable;
import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG;
import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static com.facebook.presto.type.TypeUtils.resolveTypes;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Lists.transform;
import static java.lang.String.format;
import static java.util.stream.Collectors.toList;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

/**
 * Tests for {@link FunctionRegistry}: coercion lookup, operator resolution,
 * magic literal functions, registration conflicts, function listing and
 * overload resolution through {@code resolveFunction}.
 */
public class TestFunctionRegistry
{
    @Test
    public void testIdentityCast()
    {
        TypeRegistry typeManager = new TypeRegistry();
        FunctionRegistry registry = createFunctionRegistry(typeManager);

        Signature exactOperator = registry.getCoercion(HYPER_LOG_LOG, HYPER_LOG_LOG);

        assertEquals(exactOperator.getName(), mangleOperatorName(OperatorType.CAST.name()));
        assertEquals(transform(exactOperator.getArgumentTypes(), Functions.toStringFunction()), ImmutableList.of(StandardTypes.HYPER_LOG_LOG));
        assertEquals(exactOperator.getReturnType().getBase(), StandardTypes.HYPER_LOG_LOG);
    }

    @Test
    public void testExactMatchBeforeCoercion()
    {
        TypeRegistry typeManager = new TypeRegistry();
        FunctionRegistry registry = createFunctionRegistry(typeManager);

        boolean foundOperator = false;
        for (SqlFunction function : registry.listOperators()) {
            Signature declared = function.getSignature();
            OperatorType operatorType = unmangleOperator(declared.getName());

            // Cast operators are skipped here.
            if (operatorType == OperatorType.CAST || operatorType == OperatorType.SATURATED_FLOOR_CAST) {
                continue;
            }
            // Skip operators declared with type variables.
            if (!declared.getTypeVariableConstraints().isEmpty()) {
                continue;
            }
            // Skip operators whose argument types are calculated.
            if (declared.getArgumentTypes().stream().anyMatch(TypeSignature::isCalculated)) {
                continue;
            }

            Signature exactOperator = registry.resolveOperator(operatorType, resolveTypes(declared.getArgumentTypes(), typeManager));
            assertEquals(exactOperator, declared);
            foundOperator = true;
        }
        // Guard against the loop silently checking nothing.
        assertTrue(foundOperator);
    }

    @Test
    public void testMagicLiteralFunction()
    {
        Signature signature = getMagicLiteralFunctionSignature(TIMESTAMP_WITH_TIME_ZONE);
        assertEquals(signature.getName(), "$literal$timestamp with time zone");
        assertEquals(signature.getArgumentTypes(), ImmutableList.of(parseTypeSignature(StandardTypes.BIGINT)));
        assertEquals(signature.getReturnType().getBase(), StandardTypes.TIMESTAMP_WITH_TIME_ZONE);

        TypeRegistry typeManager = new TypeRegistry();
        FunctionRegistry registry = createFunctionRegistry(typeManager);
        Signature function = registry.resolveFunction(QualifiedName.of(signature.getName()), fromTypeSignatures(signature.getArgumentTypes()));
        assertEquals(function.getArgumentTypes(), ImmutableList.of(parseTypeSignature(StandardTypes.BIGINT)));
        // Verify the resolved function, not the signature that was already checked above.
        assertEquals(function.getReturnType().getBase(), StandardTypes.TIMESTAMP_WITH_TIME_ZONE);
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "\\QFunction already registered: custom_add(bigint,bigint):bigint\\E")
    public void testDuplicateFunctions()
    {
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalars(CustomFunctions.class)
                .getFunctions()
                .stream()
                .filter(input -> input.getSignature().getName().equals("custom_add"))
                .collect(toImmutableList());

        TypeRegistry typeManager = new TypeRegistry();
        FunctionRegistry registry = createFunctionRegistry(typeManager);
        registry.addFunctions(functions);
        // Registering the same functions a second time is expected to throw.
        registry.addFunctions(functions);
    }

    @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "'sum' is both an aggregation and a scalar function")
    public void testConflictingScalarAggregation()
            throws Exception
    {
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalars(ScalarSum.class)
                .getFunctions();

        TypeRegistry typeManager = new TypeRegistry();
        FunctionRegistry registry = createFunctionRegistry(typeManager);
        registry.addFunctions(functions);
    }

    @Test
    public void testListingHiddenFunctions()
            throws Exception
    {
        TypeRegistry typeManager = new TypeRegistry();
        FunctionRegistry registry = createFunctionRegistry(typeManager);
        List<SqlFunction> functions = registry.list();
        List<String> names = transform(functions, input -> input.getSignature().getName());

        assertTrue(names.contains("length"), "Expected function names " + names + " to contain 'length'");
        assertTrue(names.contains("stddev"), "Expected function names " + names + " to contain 'stddev'");
        assertTrue(names.contains("rank"), "Expected function names " + names + " to contain 'rank'");
        assertFalse(names.contains("like"), "Expected function names " + names + " not to contain 'like'");
    }

    @Test
    public void testResolveFunctionByExactMatch()
            throws Exception
    {
        assertThatResolveFunction()
                .among(functionSignature("bigint", "bigint"))
                .forParameters("bigint", "bigint")
                .returns(functionSignature("bigint", "bigint"));
    }

    @Test
    public void testResolveTypeParametrizedFunction()
            throws Exception
    {
        assertThatResolveFunction()
                .among(functionSignature(ImmutableList.of("T", "T"), "boolean", ImmutableList.of(typeVariable("T"))))
                .forParameters("bigint", "bigint")
                .returns(functionSignature("bigint", "bigint"));
    }

    @Test
    public void testResolveFunctionWithCoercion()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("decimal(p,s)", "double"),
                        functionSignature("decimal(p,s)", "decimal(p,s)"),
                        functionSignature("double", "double"))
                .forParameters("bigint", "bigint")
                .returns(functionSignature("decimal(19,0)", "decimal(19,0)"));
    }

    @Test
    public void testAmbiguousCallWithNoCoercion()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("decimal(p,s)", "decimal(p,s)"),
                        functionSignature(ImmutableList.of("T", "T"), "boolean", ImmutableList.of(typeVariable("T"))))
                .forParameters("decimal(3,1)", "decimal(3,1)")
                .returns(functionSignature("decimal(3,1)", "decimal(3,1)"));
    }

    @Test
    public void testAmbiguousCallWithCoercion()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("decimal(p,s)", "double"),
                        functionSignature("double", "decimal(p,s)"))
                .forParameters("bigint", "bigint")
                .failsWithMessage("Could not choose a best candidate operator. Explicit type casts must be added.");
    }

    @Test
    public void testResolveFunctionWithCoercionInTypes()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("array(decimal(p,s))", "array(double)"),
                        functionSignature("array(decimal(p,s))", "array(decimal(p,s))"),
                        functionSignature("array(double)", "array(double)"))
                .forParameters("array(bigint)", "array(bigint)")
                .returns(functionSignature("array(decimal(19,0))", "array(decimal(19,0))"));
    }

    @Test
    public void testResolveFunctionWithVariableArity()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("double", "double", "double"),
                        functionSignature("decimal(p,s)").setVariableArity(true))
                .forParameters("bigint", "bigint", "bigint")
                .returns(functionSignature("decimal(19,0)", "decimal(19,0)", "decimal(19,0)"));

        assertThatResolveFunction()
                .among(
                        functionSignature("double", "double", "double"),
                        functionSignature("bigint").setVariableArity(true))
                .forParameters("bigint", "bigint", "bigint")
                .returns(functionSignature("bigint", "bigint", "bigint"));
    }

    @Test
    public void testResolveFunctionWithVariadicBound()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint", "bigint", "bigint"),
                        functionSignature(
                                ImmutableList.of("T1", "T2", "T3"),
                                "boolean",
                                ImmutableList.of(
                                        Signature.withVariadicBound("T1", "decimal"),
                                        Signature.withVariadicBound("T2", "decimal"),
                                        Signature.withVariadicBound("T3", "decimal"))))
                .forParameters("unknown", "bigint", "bigint")
                .returns(functionSignature("bigint", "bigint", "bigint"));
    }

    @Test
    public void testResolveFunctionForUnknown()
            throws Exception
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint"))
                .forParameters("unknown")
                .returns(functionSignature("bigint"));

        // when a coercion between the types exists, and the most specific function can be determined with the main algorithm
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint"),
                        functionSignature("integer"))
                .forParameters("unknown")
                .returns(functionSignature("integer"));

        // a function that requires only unknown coercion must be preferred
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint", "bigint"),
                        functionSignature("integer", "integer"))
                .forParameters("unknown", "bigint")
                .returns(functionSignature("bigint", "bigint"));

        // when no coercion between the types exists, but the return type is the same, an arbitrary function must be chosen
        assertThatResolveFunction()
                .among(
                        functionSignature(ImmutableList.of("JoniRegExp"), "boolean"),
                        functionSignature(ImmutableList.of("integer"), "boolean"))
                .forParameters("unknown")
                // any function can be selected, but to make it deterministic we sort function signatures alphabetically
                .returns(functionSignature("integer"));

        // when the return type is different
        assertThatResolveFunction()
                .among(
                        functionSignature(ImmutableList.of("JoniRegExp"), "JoniRegExp"),
                        functionSignature(ImmutableList.of("integer"), "integer"))
                .forParameters("unknown")
                .failsWithMessage("Could not choose a best candidate operator. Explicit type casts must be added.");
    }

    /**
     * Builds a {@link FunctionRegistry} over the given type manager, using a fresh
     * {@link BlockEncodingManager} and a default {@link FeaturesConfig}.
     */
    private static FunctionRegistry createFunctionRegistry(TypeRegistry typeManager)
    {
        return new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig());
    }

    /**
     * Creates a scalar signature builder with the given argument types and a
     * {@code boolean} return type.
     */
    private static SignatureBuilder functionSignature(String... argumentTypes)
    {
        return functionSignature(ImmutableList.copyOf(argumentTypes), "boolean");
    }

    /**
     * Creates a scalar signature builder with the given argument and return types
     * and no type variable constraints.
     */
    private static SignatureBuilder functionSignature(List<String> arguments, String returnType)
    {
        return functionSignature(arguments, returnType, ImmutableList.of());
    }

    /**
     * Creates a scalar signature builder. Argument and return type strings are parsed
     * with {@code p}, {@code s}, {@code p1}..{@code p3} and {@code s1}..{@code s3}
     * passed as literal parameter names.
     */
    private static SignatureBuilder functionSignature(List<String> arguments, String returnType, List<TypeVariableConstraint> typeVariableConstraints)
    {
        ImmutableSet<String> literalParameters = ImmutableSet.of("p", "s", "p1", "s1", "p2", "s2", "p3", "s3");
        List<TypeSignature> argumentSignatures = arguments.stream()
                .map((signature) -> TypeSignature.parseTypeSignature(signature, literalParameters))
                .collect(toImmutableList());
        return new SignatureBuilder()
                .returnType(TypeSignature.parseTypeSignature(returnType, literalParameters))
                .argumentTypes(argumentSignatures)
                .typeVariableConstraints(typeVariableConstraints)
                .kind(SCALAR);
    }

    private static ResolveFunctionAssertion assertThatResolveFunction()
    {
        return new ResolveFunctionAssertion();
    }

    /**
     * Small fluent helper: registers a set of candidate signatures under a fixed
     * test name, resolves a call with the given parameter types, and asserts on
     * the outcome.
     */
    private static class ResolveFunctionAssertion
    {
        private static final String TEST_FUNCTION_NAME = "TEST_FUNCTION_NAME";

        private final TypeRegistry typeRegistry = new TypeRegistry();
        private final BlockEncodingSerde blockEncoding = new BlockEncodingManager(typeRegistry);

        private List<SignatureBuilder> functionSignatures = ImmutableList.of();
        private List<TypeSignature> parameterTypes = ImmutableList.of();

        /**
         * Sets the candidate signatures to register before resolution.
         */
        public ResolveFunctionAssertion among(SignatureBuilder... functionSignatures)
        {
            this.functionSignatures = ImmutableList.copyOf(functionSignatures);
            return this;
        }

        /**
         * Sets the call-site parameter types, given as type signature strings.
         */
        public ResolveFunctionAssertion forParameters(String... parameters)
        {
            this.parameterTypes = parseTypeSignatures(parameters);
            return this;
        }

        /**
         * Asserts that resolution yields exactly the given signature (named with the test function name).
         */
        public ResolveFunctionAssertion returns(SignatureBuilder functionSignature)
        {
            Signature expectedSignature = functionSignature.name(TEST_FUNCTION_NAME).build();
            Signature actualSignature = resolveSignature();
            assertEquals(actualSignature, expectedSignature);
            return this;
        }

        /**
         * Asserts that resolution throws a {@link RuntimeException} whose message
         * contains every one of the given fragments.
         */
        public ResolveFunctionAssertion failsWithMessage(String... messages)
        {
            try {
                resolveSignature();
                // fail() throws AssertionError, which the catch below does not intercept.
                fail("didn't fail as expected");
            }
            catch (RuntimeException e) {
                String actualMessage = e.getMessage();
                for (String expectedMessage : messages) {
                    // A null message is reported as a regular assertion failure rather than an NPE.
                    if (actualMessage == null || !actualMessage.contains(expectedMessage)) {
                        fail(format("%s doesn't contain %s", actualMessage, expectedMessage));
                    }
                }
            }
            return this;
        }

        private Signature resolveSignature()
        {
            FunctionRegistry functionRegistry = new FunctionRegistry(typeRegistry, blockEncoding, new FeaturesConfig());
            functionRegistry.addFunctions(createFunctionsFromSignatures());
            return functionRegistry.resolveFunction(QualifiedName.of(TEST_FUNCTION_NAME), fromTypeSignatures(parameterTypes));
        }

        /**
         * Wraps each candidate signature in a stub {@link SqlScalarFunction} so that
         * it can be registered.
         */
        private List<SqlFunction> createFunctionsFromSignatures()
        {
            ImmutableList.Builder<SqlFunction> functions = ImmutableList.builder();
            for (SignatureBuilder functionSignature : functionSignatures) {
                Signature signature = functionSignature.name(TEST_FUNCTION_NAME).build();
                functions.add(new SqlScalarFunction(signature)
                {
                    @Override
                    public ScalarFunctionImplementation specialize(
                            BoundVariables boundVariables,
                            int arity,
                            TypeManager typeManager,
                            FunctionRegistry functionRegistry)
                    {
                        return new ScalarFunctionImplementation(false, Collections.nCopies(arity, Boolean.FALSE), MethodHandles.identity(Void.class), true);
                    }

                    @Override
                    public boolean isHidden()
                    {
                        return false;
                    }

                    @Override
                    public boolean isDeterministic()
                    {
                        return false;
                    }

                    @Override
                    public String getDescription()
                    {
                        return "testing function that does nothing";
                    }
                });
            }
            return functions.build();
        }

        private static List<TypeSignature> parseTypeSignatures(String... signatures)
        {
            return ImmutableList.copyOf(signatures)
                    .stream()
                    .map(TypeSignature::parseTypeSignature)
                    .collect(toList());
        }
    }

    /**
     * Declares a scalar function named {@code sum}; used by
     * {@code testConflictingScalarAggregation}, which expects registering it to fail.
     */
    public static final class ScalarSum
    {
        private ScalarSum() {}

        @ScalarFunction
        @SqlType(StandardTypes.BIGINT)
        public static long sum(@SqlType(StandardTypes.BIGINT) long a, @SqlType(StandardTypes.BIGINT) long b)
        {
            return a + b;
        }
    }
}