/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import javax.annotation.Nullable;
import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables;
import static com.facebook.presto.operator.aggregation.AggregationCompiler.isParameterBlock;
import static com.facebook.presto.operator.aggregation.AggregationCompiler.isParameterNullable;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.fromSqlType;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.invoke.MethodHandles.lookup;
import static java.util.Objects.requireNonNull;
/**
 * An aggregation function described by an annotated definition class whose
 * signature is bound to concrete types at specialization time.
 *
 * <p>{@link #specialize} applies the bound variables to the declared signature,
 * resolves the resulting argument and return types, generates the state
 * serializer and factory for the state class, and packages the input, combine
 * and (optional) output methods into an {@link InternalAggregationFunction}.
 */
public class BindableAggregationFunction
        extends SqlAggregationFunction
{
    private final String description;
    private final boolean decomposable;
    // Class carrying the @AggregationFunction annotation and the combine function.
    private final Class<?> definitionClass;
    // Accumulator state class handed to StateCompiler.
    private final Class<?> stateClass;
    private final Method inputFunction;
    // May be null; specialize() then passes a null output handle to AggregationMetadata.
    private final Method outputFunction;

    /**
     * @param signature the (possibly generic) signature forwarded to the superclass
     * @param description text returned by {@link #getDescription()}
     * @param decomposable flag forwarded unchanged to the produced {@link InternalAggregationFunction}
     * @param definitionClass class that must carry {@link AggregationFunction}; also searched for the combine function
     * @param stateClass accumulator state class used to generate the serializer and factory
     * @param inputFunction reflective input method; its parameter annotations drive the parameter metadata
     * @param outputFunction reflective output method, or {@code null} when there is none
     */
    public BindableAggregationFunction(Signature signature,
            String description,
            boolean decomposable,
            Class<?> definitionClass,
            Class<?> stateClass,
            Method inputFunction,
            Method outputFunction)
    {
        super(signature);
        this.description = description;
        this.decomposable = decomposable;
        this.definitionClass = definitionClass;
        this.stateClass = stateClass;
        this.inputFunction = inputFunction;
        this.outputFunction = outputFunction;
    }

    @Override
    public String getDescription()
    {
        return description;
    }

    /**
     * Binds the signature's variables and builds the concrete aggregation.
     *
     * @param variables bound variables applied to the declared signature
     * @param arity number of arguments, passed to the signature binder
     * @param typeManager resolves type signatures to {@link Type} instances
     * @param functionRegistry not used by this implementation
     * @return the specialized aggregation function
     * @throws NullPointerException if the definition class lacks {@link AggregationFunction}
     * @throws RuntimeException if one of the reflective methods is not accessible
     */
    @Override
    public InternalAggregationFunction specialize(BoundVariables variables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
    {
        // bind variables
        Signature boundSignature = applyBoundVariables(getSignature(), variables, arity);
        List<Type> inputTypes = boundSignature.getArgumentTypes().stream()
                .map(typeManager::getType)
                .collect(toImmutableList());
        Type outputType = typeManager.getType(boundSignature.getReturnType());

        // The annotation itself is not read further; it only has to be present.
        AggregationFunction aggregationAnnotation = definitionClass.getAnnotation(AggregationFunction.class);
        requireNonNull(aggregationAnnotation, "aggregationAnnotation is null");

        DynamicClassLoader classLoader = new DynamicClassLoader(definitionClass.getClassLoader(), getClass().getClassLoader());

        AccumulatorStateSerializer<?> stateSerializer = StateCompiler.generateStateSerializer(stateClass, classLoader);
        Type intermediateType = stateSerializer.getSerializedType();
        Method combineFunction = AggregationCompiler.getCombineFunction(definitionClass, stateClass);
        AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);

        AggregationMetadata metadata;
        try {
            MethodHandle inputHandle = lookup().unreflect(inputFunction);
            MethodHandle combineHandle = lookup().unreflect(combineFunction);
            MethodHandle outputHandle = outputFunction == null ? null : lookup().unreflect(outputFunction);
            metadata = new AggregationMetadata(
                    generateAggregationName(getSignature().getName(), outputType.getTypeSignature(), signaturesFromTypes(inputTypes)),
                    getParameterMetadata(inputFunction, inputTypes),
                    inputHandle,
                    combineHandle,
                    outputHandle,
                    stateClass,
                    stateSerializer,
                    stateFactory,
                    outputType);
        }
        catch (IllegalAccessException e) {
            throw Throwables.propagate(e);
        }

        AccumulatorFactoryBinder factory = new LazyAccumulatorFactoryBinder(metadata, classLoader);
        return new InternalAggregationFunction(getSignature().getName(),
                inputTypes,
                intermediateType,
                outputType,
                decomposable,
                factory);
    }

    /**
     * Convenience overload that specializes with a {@code null} function registry.
     */
    public InternalAggregationFunction specialize(BoundVariables variables, int arity, TypeManager typeManager)
    {
        return specialize(variables, arity, typeManager, null);
    }

    /**
     * Maps each type to its type signature, preserving order.
     */
    private static List<TypeSignature> signaturesFromTypes(List<Type> types)
    {
        return types.stream()
                .map(Type::getTypeSignature)
                .collect(toImmutableList());
    }

    /**
     * Derives one {@link ParameterMetadata} per parameter of {@code method} from
     * its parameter annotations.
     *
     * <p>Each {@code @SqlType} parameter consumes the next entry of
     * {@code inputTypes} in declaration order; {@code @BlockIndex} and
     * {@code @AggregationState} parameters do not consume an input type.
     *
     * @param method the input method to inspect, or {@code null}
     * @param inputTypes resolved SQL argument types, in argument order
     * @return the parameter metadata in parameter order, or {@code null} when {@code method} is {@code null}
     * @throws IllegalArgumentException if the method has no parameters or a parameter is annotated inconsistently
     * @throws IndexOutOfBoundsException if the method declares more {@code @SqlType} parameters than there are input types
     */
    private static List<ParameterMetadata> getParameterMetadata(@Nullable Method method, List<Type> inputTypes)
    {
        if (method == null) {
            return null;
        }

        ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
        Annotation[][] annotations = method.getParameterAnnotations();
        String methodName = method.getDeclaringClass() + "." + method.getName();

        checkArgument(annotations.length > 0, "At least @AggregationState argument is required for each of aggregation functions.");

        // Index of the next SQL input type; advanced only by @SqlType parameters so
        // that @BlockIndex / @AggregationState parameters may appear at any position.
        int inputId = 0;
        int i = 0;
        if (annotations[0].length == 0) {
            // Backward compatibility - first argument without annotations is interpreted as State argument
            builder.add(new ParameterMetadata(STATE));
            i++;
        }

        for (; i < annotations.length; i++) {
            Annotation baseTypeAnnotation = baseTypeAnnotation(annotations[i], methodName);
            if (baseTypeAnnotation instanceof SqlType) {
                builder.add(fromSqlType(inputTypes.get(inputId++), isParameterBlock(annotations[i]), isParameterNullable(annotations[i]), methodName));
            }
            else if (baseTypeAnnotation instanceof BlockIndex) {
                builder.add(new ParameterMetadata(BLOCK_INDEX));
            }
            else if (baseTypeAnnotation instanceof AggregationState) {
                builder.add(new ParameterMetadata(STATE));
            }
            else {
                // Report the annotation itself; concatenating the array would only print its identity.
                throw new IllegalArgumentException("Unsupported annotation: " + baseTypeAnnotation);
            }
        }
        return builder.build();
    }

    /**
     * Returns the single annotation among {@code @SqlType}, {@code @BlockIndex}
     * and {@code @AggregationState} present on a parameter.
     *
     * @param annotations all annotations of one parameter
     * @param methodName name used in error messages
     * @return the parameter's base annotation
     * @throws IllegalArgumentException if the parameter does not carry exactly one base annotation,
     *         or if it is marked as block/nullable without being {@code @SqlType}
     */
    private static Annotation baseTypeAnnotation(Annotation[] annotations, String methodName)
    {
        List<Annotation> baseTypes = Arrays.stream(annotations)
                .filter(annotation -> annotation instanceof SqlType || annotation instanceof BlockIndex || annotation instanceof AggregationState)
                .collect(toImmutableList());

        // Note: @AggregationState is also accepted here although the message does not list it.
        checkArgument(baseTypes.size() == 1, "Parameter of %s must have exactly one of @SqlType, @BlockIndex", methodName);

        boolean nullable = isParameterNullable(annotations);
        boolean isBlock = isParameterBlock(annotations);

        Annotation annotation = baseTypes.get(0);
        checkArgument((!isBlock && !nullable) || (annotation instanceof SqlType),
                "%s contains a parameter with @BlockPosition and/or @NullablePosition that is not @SqlType", methodName);

        return annotation;
    }
}