/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.LongVariableConstraint;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.type.Constraint;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import javax.annotation.Nullable;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
/**
 * Static helpers that use reflection to turn a class annotated with {@link AggregationFunction}
 * into one or more {@link BindableAggregationFunction} instances.
 *
 * <p>The annotated class is scanned for public static methods carrying {@link InputFunction},
 * {@link OutputFunction} and {@link CombineFunction}. Methods are grouped by the type of their
 * aggregation state parameter (see {@link #findAggregationStateParamId(Method, int)}).
 */
public class AggregationCompiler
{
    private AggregationCompiler()
    {
    }

    /**
     * Generates the single aggregation function defined by {@code clazz}.
     *
     * <p>This function should only be used for function matching for testing purposes.
     * General purpose function matching is done through FunctionRegistry.
     *
     * @param clazz class annotated with {@link AggregationFunction}
     * @return the only aggregation function generated from {@code clazz}
     * @throws IllegalArgumentException if {@code clazz} does not yield exactly one function
     */
    @VisibleForTesting
    public static BindableAggregationFunction generateAggregationBindableFunction(Class<?> clazz)
    {
        List<BindableAggregationFunction> aggregations = generateBindableAggregationFunctions(clazz);
        checkArgument(aggregations.size() == 1, "More than one aggregation function found");
        return aggregations.get(0);
    }

    /**
     * Generates the aggregation functions defined by {@code clazz} and returns the first one whose
     * signature has exactly the given return type and argument types.
     *
     * <p>This function should only be used for function matching for testing purposes.
     * General purpose function matching is done through FunctionRegistry.
     *
     * @param clazz class annotated with {@link AggregationFunction}
     * @param returnType required return type of the signature
     * @param argumentTypes required argument types of the signature, in order
     * @return the first generated function with a matching signature
     * @throws NullPointerException if {@code returnType} or {@code argumentTypes} is null
     * @throws IllegalArgumentException if no generated function matches
     */
    public static BindableAggregationFunction generateAggregationBindableFunction(Class<?> clazz, TypeSignature returnType, List<TypeSignature> argumentTypes)
    {
        requireNonNull(returnType, "returnType is null");
        requireNonNull(argumentTypes, "argumentTypes is null");
        for (BindableAggregationFunction aggregation : generateBindableAggregationFunctions(clazz)) {
            Signature signature = aggregation.getSignature();
            if (signature.getReturnType().equals(returnType) && signature.getArgumentTypes().equals(argumentTypes)) {
                return aggregation;
            }
        }
        throw new IllegalArgumentException(String.format("No method with return type %s and arguments %s", returnType, argumentTypes));
    }

    /**
     * Generates one {@link BindableAggregationFunction} for every combination of state class,
     * output function, input function and function name found on {@code aggregationDefinition}.
     *
     * @param aggregationDefinition class annotated with {@link AggregationFunction}
     * @return the generated functions, in discovery order
     * @throws NullPointerException if the class is not annotated with {@link AggregationFunction}
     * @throws IllegalArgumentException if the class has no input functions, or a state class has
     * no matching output function, or an annotated method is malformed
     */
    public static List<BindableAggregationFunction> generateBindableAggregationFunctions(Class<?> aggregationDefinition)
    {
        AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class);
        requireNonNull(aggregationAnnotation, "aggregationAnnotation is null");

        DynamicClassLoader classLoader = new DynamicClassLoader(aggregationDefinition.getClassLoader());

        ImmutableList.Builder<BindableAggregationFunction> builder = ImmutableList.builder();
        for (Class<?> stateClass : getStateClasses(aggregationDefinition)) {
            // NOTE(review): the serializer is not used below; the call is retained because
            // generation may validate the state class or have other side effects -- TODO confirm.
            AccumulatorStateSerializer<?> stateSerializer = StateCompiler.generateStateSerializer(stateClass, classLoader);

            for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) {
                // Everything derived only from the output function is computed once per output function.
                TypeSignature outputType = parseTypeSignature(outputFunction.getAnnotation(OutputFunction.class).value());
                List<String> names = getNames(outputFunction, aggregationAnnotation);
                String description = getDescription(aggregationDefinition, outputFunction);

                for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) {
                    // Everything derived only from the input function is computed once per input function.
                    List<LongVariableConstraint> longVariableConstraints = parseLongVariableConstraints(inputFunction);
                    List<TypeSignature> inputTypes = getInputTypesSignatures(inputFunction);

                    for (String name : names) {
                        builder.add(
                                new BindableAggregationFunction(
                                        new Signature(
                                                name,
                                                FunctionKind.AGGREGATE,
                                                ImmutableList.of(), // TODO parse constraints from annotations
                                                longVariableConstraints,
                                                outputType,
                                                inputTypes,
                                                false),
                                        description,
                                        aggregationAnnotation.decomposable(),
                                        aggregationDefinition,
                                        stateClass,
                                        inputFunction,
                                        outputFunction));
                    }
                }
            }
        }
        return builder.build();
    }

    /**
     * Converts every {@link Constraint} annotation on {@code inputFunction} into a
     * {@link LongVariableConstraint}.
     */
    private static List<LongVariableConstraint> parseLongVariableConstraints(Method inputFunction)
    {
        return Stream.of(inputFunction.getAnnotationsByType(Constraint.class))
                .map(annotation -> new LongVariableConstraint(annotation.variable(), annotation.expression()))
                .collect(toImmutableList());
    }

    /**
     * Returns true if any of the given parameter annotations is a {@code NullablePosition}.
     */
    public static boolean isParameterNullable(Annotation[] annotations)
    {
        return Arrays.stream(annotations).anyMatch(annotation -> annotation instanceof NullablePosition);
    }

    /**
     * Returns true if any of the given parameter annotations is a {@code BlockPosition}.
     */
    public static boolean isParameterBlock(Annotation[] annotations)
    {
        return Arrays.stream(annotations).anyMatch(annotation -> annotation instanceof BlockPosition);
    }

    /**
     * Returns the function names (primary name followed by aliases). An {@link AggregationFunction}
     * annotation on the output function, when present, takes precedence over the class-level one.
     *
     * @param outputFunction output function to inspect; may be null, in which case the
     * class-level names are returned
     * @param aggregationAnnotation class-level annotation supplying the default names
     */
    private static List<String> getNames(@Nullable Method outputFunction, AggregationFunction aggregationAnnotation)
    {
        AggregationFunction effectiveAnnotation = aggregationAnnotation;
        if (outputFunction != null) {
            AggregationFunction override = outputFunction.getAnnotation(AggregationFunction.class);
            if (override != null) {
                effectiveAnnotation = override;
            }
        }
        return ImmutableList.<String>builder()
                .add(effectiveAnnotation.value())
                .addAll(Arrays.asList(effectiveAnnotation.alias()))
                .build();
    }

    /**
     * Finds the single {@link CombineFunction} of {@code clazz} whose first two aggregation state
     * parameters are both of type {@code stateClass}.
     *
     * @param clazz class to scan for public static {@link CombineFunction} methods
     * @param stateClass state class both state parameters must have
     * @return the matching combine function
     * @throws IllegalArgumentException if there is not exactly one match, or a candidate method
     * does not have enough parameters to hold its state arguments
     */
    public static Method getCombineFunction(Class<?> clazz, Class<?> stateClass)
    {
        // Only include methods that match this state class
        List<Method> combineFunctions = findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class).stream()
                .filter(method -> findAggregationStateParamType(method, 0) == stateClass)
                .filter(method -> findAggregationStateParamType(method, 1) == stateClass)
                .collect(toImmutableList());
        checkArgument(
                combineFunctions.size() == 1,
                "There must be exactly one @CombineFunction in class %s for the @AggregationState %s ",
                clazz.toGenericString(),
                stateClass.toGenericString());
        return getOnlyElement(combineFunctions);
    }

    /**
     * Returns the {@link OutputFunction} methods of {@code clazz} whose aggregation state
     * parameter is of type {@code stateClass}.
     *
     * @throws IllegalArgumentException if there is no such method, or a candidate method has no
     * parameter that can hold the state
     */
    private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass)
    {
        // Only include methods that match this state class
        List<Method> outputFunctions = findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class).stream()
                .filter(method -> findAggregationStateParamType(method) == stateClass)
                .collect(toImmutableList());
        checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions");
        return outputFunctions;
    }

    /**
     * Returns the {@link InputFunction} methods of {@code clazz} whose aggregation state
     * parameter is of type {@code stateClass}.
     *
     * @throws IllegalArgumentException if there is no such method, or a candidate method has no
     * parameter that can hold the state
     */
    private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass)
    {
        // Only include methods that match this state class
        List<Method> inputFunctions = findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class).stream()
                .filter(method -> findAggregationStateParamType(method) == stateClass)
                .collect(toImmutableList());
        checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
        return inputFunctions;
    }

    /**
     * Parses the value of every {@link SqlType} parameter annotation of {@code inputFunction}
     * into a type signature, in parameter order, using the method's literal parameters.
     */
    private static List<TypeSignature> getInputTypesSignatures(Method inputFunction)
    {
        // FIXME Literal parameters should be part of class annotations.
        Set<String> literalParameters = getLiteralParameter(inputFunction);
        ImmutableList.Builder<TypeSignature> builder = ImmutableList.builder();
        for (Annotation[] annotations : inputFunction.getParameterAnnotations()) {
            for (Annotation annotation : annotations) {
                if (annotation instanceof SqlType) {
                    builder.add(parseTypeSignature(((SqlType) annotation).value(), literalParameters));
                }
            }
        }
        return builder.build();
    }

    /**
     * Collects the distinct aggregation state types used by the {@link InputFunction} methods
     * of {@code clazz}.
     *
     * @throws IllegalArgumentException if there are no input functions, an input function has no
     * parameters, or a state type does not implement {@link AccumulatorState}
     */
    private static Set<Class<?>> getStateClasses(Class<?> clazz)
    {
        ImmutableSet.Builder<Class<?>> builder = ImmutableSet.builder();
        for (Method inputFunction : findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) {
            checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters");
            Class<?> stateClass = findAggregationStateParamType(inputFunction);
            checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState");
            builder.add(stateClass);
        }
        ImmutableSet<Class<?>> stateClasses = builder.build();
        checkArgument(!stateClasses.isEmpty(), "No input functions found");
        return stateClasses;
    }

    /**
     * Returns the declared type of the first aggregation state parameter of {@code method}.
     *
     * @throws IllegalArgumentException if the method has no parameter at the resolved position
     */
    private static Class<?> findAggregationStateParamType(Method method)
    {
        return findAggregationStateParamType(method, 0);
    }

    /**
     * Returns the declared type of the {@code id}-th (zero-based) aggregation state parameter
     * of {@code method}.
     *
     * <p>The position comes from {@link #findAggregationStateParamId(Method, int)}, which may fall
     * back to {@code id} itself; that position is validated here so that a method with too few
     * parameters is reported with a descriptive error instead of an
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @throws IllegalArgumentException if the method has no parameter at the resolved position
     */
    private static Class<?> findAggregationStateParamType(Method method, int id)
    {
        Class<?>[] parameterTypes = method.getParameterTypes();
        int stateParamId = findAggregationStateParamId(method, id);
        checkArgument(
                stateParamId < parameterTypes.length,
                "Method %s has no parameter at index %s to hold the aggregation state",
                method,
                stateParamId);
        return parameterTypes[stateParamId];
    }

    /**
     * Returns the parameter index of the first aggregation state parameter of {@code method}.
     * Equivalent to {@code findAggregationStateParamId(method, 0)}.
     */
    public static int findAggregationStateParamId(Method method)
    {
        return findAggregationStateParamId(method, 0);
    }

    /**
     * Returns the parameter index of the {@code id}-th (zero-based) parameter of {@code method}
     * annotated with {@link AggregationState}.
     *
     * <p>If the method has fewer than {@code id + 1} such parameters, {@code id} itself is
     * returned. The returned value is therefore not guaranteed to be a valid parameter index;
     * callers that index into the parameter list should check it.
     *
     * @param method method whose parameters are inspected
     * @param id zero-based ordinal of the wanted state parameter
     * @return index of the matching parameter, or {@code id} when there is no match
     */
    public static int findAggregationStateParamId(Method method, int id)
    {
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        int found = 0;
        for (int paramId = 0; paramId < parameterAnnotations.length; paramId++) {
            for (Annotation annotation : parameterAnnotations[paramId]) {
                if (annotation instanceof AggregationState) {
                    if (found == id) {
                        return paramId;
                    }
                    found++;
                }
            }
        }
        // backward compatibility: the @AggregationState annotation didn't exist before,
        // so some third party aggregates may assume that the state is the id-th parameter
        return id;
    }

    /**
     * Returns the {@link Description} value of {@code override} if it has one, otherwise that of
     * {@code base}, otherwise null.
     */
    private static String getDescription(AnnotatedElement base, AnnotatedElement override)
    {
        Description description = override.getAnnotation(Description.class);
        if (description == null) {
            description = base.getAnnotation(Description.class);
        }
        return (description == null) ? null : description.value();
    }

    /**
     * Returns the literal parameter names declared through {@link LiteralParameters} on
     * {@code inputFunction}, or an empty set when the annotation is absent.
     */
    private static Set<String> getLiteralParameter(Method inputFunction)
    {
        LiteralParameters literalParameters = inputFunction.getAnnotation(LiteralParameters.class);
        if (literalParameters == null) {
            return ImmutableSet.of();
        }
        return ImmutableSet.copyOf(literalParameters.value());
    }

    /**
     * Returns the methods of {@code clazz} that carry an annotation of type {@code annotationClass}.
     *
     * <p>NOTE(review): candidates come from {@link Class#getMethods()}, which only returns public
     * methods (including inherited ones). A non-public annotated method is therefore skipped
     * silently rather than rejected; only the static requirement is effectively enforced here.
     *
     * @throws IllegalArgumentException if an annotated method is not public and static
     */
    private static List<Method> findPublicStaticMethodsWithAnnotation(Class<?> clazz, Class<?> annotationClass)
    {
        ImmutableList.Builder<Method> methods = ImmutableList.builder();
        for (Method method : clazz.getMethods()) {
            boolean annotated = false;
            for (Annotation annotation : method.getAnnotations()) {
                if (annotationClass.isInstance(annotation)) {
                    annotated = true;
                    break;
                }
            }
            if (annotated) {
                int modifiers = method.getModifiers();
                checkArgument(
                        Modifier.isStatic(modifiers) && Modifier.isPublic(modifiers),
                        "%s annotated with %s must be static and public",
                        method.getName(),
                        annotationClass.getSimpleName());
                methods.add(method);
            }
        }
        return methods.build();
    }
}