/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar.annotations;

import com.facebook.presto.metadata.Signature;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.operator.scalar.ParametricScalar;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarOperator;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.type.TypeSignature;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.facebook.presto.operator.scalar.annotations.OperatorValidator.validateOperator;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;

/**
 * Static utility that turns classes carrying {@code @ScalarFunction} / {@code @ScalarOperator}
 * annotations into {@link SqlScalarFunction} instances.
 *
 * <p>Two layouts are supported:
 * <ul>
 * <li>{@link #parseFunctionDefinition(Class)}: the annotation sits on the class and the
 * public methods annotated with {@code @SqlType} are the implementations of that one function;</li>
 * <li>{@link #parseFunctionDefinitions(Class)}: each public method carries its own
 * {@code @ScalarFunction} / {@code @ScalarOperator} annotation.</li>
 * </ul>
 */
public final class ScalarFromAnnotationsParser
{
    private ScalarFromAnnotationsParser() {}

    /**
     * Parses a class that is itself annotated with {@code @ScalarFunction} or
     * {@code @ScalarOperator}.
     *
     * @param clazz the annotated class
     * @return one function per header found on the class
     * @throws IllegalArgumentException if the class has no header, has no annotated methods,
     * or one of its annotated methods is itself marked {@code @ScalarFunction} / {@code @ScalarOperator}
     */
    public static List<SqlScalarFunction> parseFunctionDefinition(Class<?> clazz)
    {
        List<ScalarHeaderAndMethods> scalars = findScalarsInFunctionDefinitionClass(clazz);
        // The constructors depend only on the class, so look them up once rather than per scalar.
        // The list above is never empty (the finder rejects that case), so this lookup always
        // happened at least once before as well.
        Map<Set<TypeParameter>, Constructor<?>> constructors = findConstructors(clazz);

        ImmutableList.Builder<SqlScalarFunction> functions = ImmutableList.builder();
        for (ScalarHeaderAndMethods scalar : scalars) {
            functions.add(parseParametricScalar(scalar, constructors));
        }
        return functions.build();
    }

    /**
     * Parses a class whose public methods are individually annotated with
     * {@code @ScalarFunction} or {@code @ScalarOperator}.
     *
     * @param clazz the class containing the annotated methods
     * @return one function per (method, header) pair
     * @throws IllegalArgumentException if no such method exists, or a method annotated with
     * {@code @SqlType} lacks both {@code @ScalarFunction} and {@code @ScalarOperator}
     */
    public static List<SqlScalarFunction> parseFunctionDefinitions(Class<?> clazz)
    {
        List<ScalarHeaderAndMethods> scalars = findScalarsInFunctionSetClass(clazz);
        // Loop-invariant: see parseFunctionDefinition.
        Map<Set<TypeParameter>, Constructor<?>> constructors = findConstructors(clazz);

        ImmutableList.Builder<SqlScalarFunction> functions = ImmutableList.builder();
        for (ScalarHeaderAndMethods scalar : scalars) {
            functions.add(parseParametricScalar(scalar, constructors));
        }
        return functions.build();
    }

    /**
     * Collects the headers declared on {@code annotated} and pairs each of them with the same
     * set of implementation methods declared by the class.
     */
    private static List<ScalarHeaderAndMethods> findScalarsInFunctionDefinitionClass(Class<?> annotated)
    {
        List<ScalarImplementationHeader> classHeaders = ScalarImplementationHeader.fromAnnotatedElement(annotated);
        checkArgument(!classHeaders.isEmpty(), "Class [%s] that defines function must be annotated with @ScalarFunction or @ScalarOperator", annotated.getName());

        // The method scan and its validation do not depend on the header, so run them once.
        Set<Method> methods = findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class);
        checkArgument(!methods.isEmpty(), "Parametric class [%s] does not have any annotated methods", annotated.getName());
        for (Method method : methods) {
            // The function/operator annotation belongs on the class in this layout, not on the methods.
            checkArgument(method.getAnnotation(ScalarFunction.class) == null, "Parametric class method [%s] is annotated with @ScalarFunction", method);
            checkArgument(method.getAnnotation(ScalarOperator.class) == null, "Parametric class method [%s] is annotated with @ScalarOperator", method);
        }

        ImmutableList.Builder<ScalarHeaderAndMethods> scalars = ImmutableList.builder();
        for (ScalarImplementationHeader header : classHeaders) {
            scalars.add(new ScalarHeaderAndMethods(header, methods));
        }
        return scalars.build();
    }

    /**
     * Collects one {@link ScalarHeaderAndMethods} per header found on each annotated public
     * method of {@code annotated}; every entry holds exactly one method.
     */
    private static List<ScalarHeaderAndMethods> findScalarsInFunctionSetClass(Class<?> annotated)
    {
        ImmutableList.Builder<ScalarHeaderAndMethods> scalars = ImmutableList.builder();
        for (Method method : findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class)) {
            boolean hasFunctionAnnotation = method.getAnnotation(ScalarFunction.class) != null;
            boolean hasOperatorAnnotation = method.getAnnotation(ScalarOperator.class) != null;
            checkArgument(hasFunctionAnnotation || hasOperatorAnnotation, "Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method);

            for (ScalarImplementationHeader header : ScalarImplementationHeader.fromAnnotatedElement(method)) {
                scalars.add(new ScalarHeaderAndMethods(header, ImmutableSet.of(method)));
            }
        }

        List<ScalarHeaderAndMethods> result = scalars.build();
        checkArgument(!result.isEmpty(), "Class [%s] does not have any methods annotated with @ScalarFunction or @ScalarOperator", annotated.getName());
        return result;
    }

    /**
     * Parses every method of {@code scalar} into a {@link ScalarImplementation}, sorts the
     * implementations into exact / specialized / generic buckets and wraps them in a
     * {@link ParametricScalar}.
     *
     * <p>An implementation is "exact" when its signature has no type variable constraints and
     * neither its argument types nor its return type are calculated. All non-exact
     * implementations must share one signature, which becomes the signature of the function;
     * when there are none, the function must have exactly one exact implementation and its
     * signature is used instead.
     *
     * @param scalar header plus the methods implementing it
     * @param constructors constructors of the defining class, keyed by their {@code @TypeParameter} annotations
     */
    private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Map<Set<TypeParameter>, Constructor<?>> constructors)
    {
        ImmutableMap.Builder<Signature, ScalarImplementation> exactImplementations = ImmutableMap.builder();
        ImmutableList.Builder<ScalarImplementation> specializedImplementations = ImmutableList.builder();
        ImmutableList.Builder<ScalarImplementation> genericImplementations = ImmutableList.builder();
        Optional<Signature> signature = Optional.empty();

        ScalarImplementationHeader header = scalar.getHeader();
        checkArgument(!header.getName().isEmpty(), "Scalar function name must not be empty");

        for (Method method : scalar.getMethods()) {
            ScalarImplementation implementation = ScalarImplementation.Parser.parseImplementation(header.getName(), method, constructors);
            Signature implementationSignature = implementation.getSignature();

            boolean exact = implementationSignature.getTypeVariableConstraints().isEmpty()
                    && implementationSignature.getArgumentTypes().stream().noneMatch(TypeSignature::isCalculated)
                    && !implementationSignature.getReturnType().isCalculated();
            if (exact) {
                exactImplementations.put(implementationSignature, implementation);
                continue;
            }

            if (implementation.hasSpecializedTypeParameters()) {
                specializedImplementations.add(implementation);
            }
            else {
                genericImplementations.add(implementation);
            }

            // The first non-exact implementation fixes the signature; later ones must match it.
            if (!signature.isPresent()) {
                signature = Optional.of(implementationSignature);
            }
            validateSignature(signature, implementationSignature);
        }

        // getOnlyElement throws unless there is exactly one exact implementation to fall back on.
        Signature scalarSignature = signature.orElseGet(() -> getOnlyElement(exactImplementations.build().keySet()));

        header.getOperatorType().ifPresent(operatorType ->
                validateOperator(operatorType, scalarSignature.getReturnType(), scalarSignature.getArgumentTypes()));

        ScalarImplementations implementations = new ScalarImplementations(
                exactImplementations.build(),
                specializedImplementations.build(),
                genericImplementations.build());

        return new ParametricScalar(scalarSignature, header.getHeader(), implementations);
    }

    /**
     * Verifies that {@code signatureNew} equals the previously recorded signature, if any.
     *
     * @throws IllegalArgumentException if a signature is recorded and the two differ
     */
    private static void validateSignature(Optional<Signature> signatureOld, Signature signatureNew)
    {
        if (!signatureOld.isPresent()) {
            return;
        }
        checkArgument(signatureOld.get().equals(signatureNew), "Implementations with type parameters must all have matching signatures. %s does not match %s", signatureOld.get(), signatureNew);
    }

    /**
     * Indexes the public constructors of {@code clazz} by the set of {@code @TypeParameter}
     * annotations present on each constructor.
     *
     * <p>NOTE(review): {@code ImmutableMap.Builder} rejects duplicate keys on {@code build()},
     * so two public constructors with equal {@code @TypeParameter} sets make this method throw.
     */
    private static Map<Set<TypeParameter>, Constructor<?>> findConstructors(Class<?> clazz)
    {
        ImmutableMap.Builder<Set<TypeParameter>, Constructor<?>> constructors = ImmutableMap.builder();
        for (Constructor<?> constructor : clazz.getConstructors()) {
            Set<TypeParameter> typeParameters = new HashSet<>();
            Stream.of(constructor.getAnnotationsByType(TypeParameter.class))
                    .forEach(typeParameters::add);
            constructors.put(typeParameters, constructor);
        }
        return constructors.build();
    }

    /**
     * Returns the methods declared directly by {@code clazz} that carry at least one of the
     * given annotations.
     *
     * @throws IllegalArgumentException if such a method is not public
     */
    @SafeVarargs
    private static Set<Method> findPublicMethodsWithAnnotation(Class<?> clazz, Class<? extends Annotation>... annotationClasses)
    {
        ImmutableSet.Builder<Method> methods = ImmutableSet.builder();
        for (Method method : clazz.getDeclaredMethods()) {
            for (Annotation annotation : method.getAnnotations()) {
                for (Class<?> annotationClass : annotationClasses) {
                    if (!annotationClass.isInstance(annotation)) {
                        continue;
                    }
                    checkArgument(Modifier.isPublic(method.getModifiers()), "Method [%s] annotated with @%s must be public", method, annotationClass.getSimpleName());
                    // A method matching several annotations is added repeatedly; the set keeps one copy.
                    methods.add(method);
                }
            }
        }
        return methods.build();
    }

    /**
     * Pairs a function header with the methods that implement it.
     */
    private static class ScalarHeaderAndMethods
    {
        private final ScalarImplementationHeader header;
        private final Set<Method> methods;

        public ScalarHeaderAndMethods(ScalarImplementationHeader header, Set<Method> methods)
        {
            this.header = requireNonNull(header, "header is null");
            this.methods = requireNonNull(methods, "methods is null");
        }

        public ScalarImplementationHeader getHeader()
        {
            return header;
        }

        public Set<Method> getMethods()
        {
            return methods;
        }
    }
}