/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.metadata; import com.facebook.presto.metadata.SqlScalarFunctionBuilder.MethodsGroup; import com.facebook.presto.metadata.SqlScalarFunctionBuilder.SpecializeContext; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.util.Reflection; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.reflect.Method; import java.util.List; import java.util.Optional; import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; import static com.facebook.presto.type.TypeUtils.resolveTypes; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; class PolymorphicScalarFunction extends SqlScalarFunction { private final String description; private final boolean hidden; private final boolean deterministic; private final boolean nullableResult; private final List<Boolean> nullableArguments; private final List<Boolean> nullFlags; private final List<MethodsGroup> methodsGroups; PolymorphicScalarFunction( Signature signature, String description, boolean hidden, boolean deterministic, boolean nullableResult, List<Boolean> nullableArguments, List<Boolean> nullFlags, List<MethodsGroup> methodsGroups) { super(signature); this.description = description; this.hidden = hidden; this.deterministic = deterministic; this.nullableResult = nullableResult; this.nullableArguments = requireNonNull(nullableArguments, "nullableArguments is null"); this.nullFlags = requireNonNull(nullFlags, "nullFlags is null"); this.methodsGroups = requireNonNull(methodsGroups, "methodsWithExtraParametersFunctions is null"); } @Override public boolean isHidden() { return hidden; } @Override public boolean isDeterministic() { return deterministic; } @Override public String getDescription() { return description; } @Override public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { List<TypeSignature> resolvedParameterTypeSignatures = applyBoundVariables(getSignature().getArgumentTypes(), boundVariables); List<Type> resolvedParameterTypes = resolveTypes(resolvedParameterTypeSignatures, typeManager); TypeSignature resolvedReturnTypeSignature = applyBoundVariables(getSignature().getReturnType(), boundVariables); Type resolvedReturnType = typeManager.getType(resolvedReturnTypeSignature); SpecializeContext context = new SpecializeContext(boundVariables, resolvedParameterTypes, resolvedReturnType, typeManager, functionRegistry); Optional<Method> matchingMethod = Optional.empty(); Optional<MethodsGroup> matchingMethodsGroup = Optional.empty(); for (MethodsGroup candidateMethodsGroup : methodsGroups) { for (Method candidateMethod : candidateMethodsGroup.getMethods()) { if (matchesParameterAndReturnTypes(candidateMethod, resolvedParameterTypes, resolvedReturnType) && predicateIsTrue(candidateMethodsGroup, context)) { if (matchingMethod.isPresent()) { if (onlyFirstMatchedMethodHasPredicate(matchingMethodsGroup.get(), candidateMethodsGroup)) { continue; } throw new IllegalStateException("two matching methods (" + matchingMethod.get().getName() + " and " + candidateMethod.getName() + ") for parameter types " + resolvedParameterTypeSignatures); } matchingMethod = Optional.of(candidateMethod); matchingMethodsGroup = Optional.of(candidateMethodsGroup); } } } checkState(matchingMethod.isPresent(), "no matching method for parameter types %s", resolvedParameterTypes); List<Object> extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context); MethodHandle matchingMethodHandle = applyExtraParameters(matchingMethod.get(), extraParameters); return new ScalarFunctionImplementation( nullableResult, nullableArguments, nullFlags, matchingMethodHandle, deterministic); } private boolean matchesParameterAndReturnTypes(Method method, List<Type> resolvedTypes, Type returnType) { checkState(method.getParameterCount() >= resolvedTypes.size(), "method %s has not enough arguments: %s (should have at least %s)", method.getName(), method.getParameterCount(), resolvedTypes.size()); Class<?>[] methodParameterJavaTypes = method.getParameterTypes(); for (int i = 0, methodParameterIndex = 0; i < resolvedTypes.size(); i++) { Class<?> type = getNullAwareContainerType(resolvedTypes.get(i).getJavaType(), nullableArguments.get(i) && !nullFlags.get(i)); if (!methodParameterJavaTypes[methodParameterIndex].equals(type)) { return false; } methodParameterIndex += nullFlags.get(i) ? 2 : 1; } return method.getReturnType().equals(getNullAwareContainerType(returnType.getJavaType(), nullableResult)); } private static boolean onlyFirstMatchedMethodHasPredicate(MethodsGroup matchingMethodsGroup, MethodsGroup methodsGroup) { return matchingMethodsGroup.getPredicate().isPresent() && !methodsGroup.getPredicate().isPresent(); } private static boolean predicateIsTrue(MethodsGroup methodsGroup, SpecializeContext context) { return methodsGroup.getPredicate().map(predicate -> predicate.test(context)).orElse(true); } private static List<Object> computeExtraParameters(MethodsGroup methodsGroup, SpecializeContext context) { return methodsGroup.getExtraParametersFunction().map(function -> function.apply(context)).orElse(emptyList()); } private int getNullFlagsCount() { int count = 0; for (boolean flag : nullFlags) { if (flag) { count++; } } return count; } private MethodHandle applyExtraParameters(Method matchingMethod, List<Object> extraParameters) { Signature signature = getSignature(); int expectedArgumentsCount = signature.getArgumentTypes().size() + getNullFlagsCount() + extraParameters.size(); int matchingMethodArgumentCount = matchingMethod.getParameterCount(); checkState(matchingMethodArgumentCount == expectedArgumentsCount, "method %s has invalid number of arguments: %s (should have %s)", matchingMethod.getName(), matchingMethodArgumentCount, expectedArgumentsCount); MethodHandle matchingMethodHandle = Reflection.methodHandle(matchingMethod); matchingMethodHandle = MethodHandles.insertArguments( matchingMethodHandle, matchingMethodArgumentCount - extraParameters.size(), extraParameters.toArray()); return matchingMethodHandle; } private static Class<?> getNullAwareContainerType(Class<?> clazz, boolean nullable) { if (nullable) { return Primitives.wrap(clazz); } checkArgument(clazz != void.class); return clazz; } }