/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Set;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
/**
 * Describes an aggregation function implemented by three method handles and validates their
 * signatures on construction.
 *
 * <p>The three handles are:
 * <ul>
 * <li>an input function whose first parameter is the state, followed by the parameters described
 * by the input metadata;</li>
 * <li>a combine function taking two states;</li>
 * <li>an output function taking a state and a {@link BlockBuilder}.</li>
 * </ul>
 *
 * <p>The metadata also carries the aggregation name, the state serializer and factory, and the SQL
 * output type.
 */
public class AggregationMetadata
{
    /**
     * Java types accepted for {@code INPUT_CHANNEL} parameters of the input function.
     */
    public static final Set<Class<?>> SUPPORTED_PARAMETER_TYPES = ImmutableSet.of(Block.class, long.class, double.class, boolean.class, Slice.class);

    private final String name;
    private final List<ParameterMetadata> inputMetadata;
    private final MethodHandle inputFunction;
    private final MethodHandle combineFunction;
    private final MethodHandle outputFunction;
    private final AccumulatorStateSerializer<?> stateSerializer;
    private final AccumulatorStateFactory<?> stateFactory;
    private final Type outputType;

    /**
     * Creates aggregation metadata after validating every method handle signature.
     *
     * @param name the aggregation name
     * @param inputMetadata one entry per parameter of {@code inputFunction}, in order; the first
     * entry must describe the state parameter
     * @param inputFunction input method handle whose first parameter type is exactly {@code stateInterface}
     * @param combineFunction method handle whose parameter types are exactly {@code (stateInterface, stateInterface)}
     * @param outputFunction method handle whose parameter types are exactly {@code (stateInterface, BlockBuilder)}
     * @param stateInterface class the state parameters of the method handles must match exactly;
     * used only for validation and not retained
     * @param stateSerializer serializer for the aggregation state
     * @param stateFactory factory for the aggregation state
     * @param outputType SQL type produced by the aggregation
     * @throws NullPointerException if any argument is null, or if {@code inputMetadata} contains a null element
     * @throws IllegalArgumentException if a method handle's parameters do not match {@code inputMetadata}
     * or {@code stateInterface}
     */
    public AggregationMetadata(
            String name,
            List<ParameterMetadata> inputMetadata,
            MethodHandle inputFunction,
            MethodHandle combineFunction,
            MethodHandle outputFunction,
            Class<?> stateInterface,
            AccumulatorStateSerializer<?> stateSerializer,
            AccumulatorStateFactory<?> stateFactory,
            Type outputType)
    {
        this.outputType = requireNonNull(outputType, "outputType is null");
        this.inputMetadata = ImmutableList.copyOf(requireNonNull(inputMetadata, "inputMetadata is null"));
        this.name = requireNonNull(name, "name is null");
        this.inputFunction = requireNonNull(inputFunction, "inputFunction is null");
        this.combineFunction = requireNonNull(combineFunction, "combineFunction is null");
        this.outputFunction = requireNonNull(outputFunction, "outputFunction is null");
        this.stateSerializer = requireNonNull(stateSerializer, "stateSerializer is null");
        this.stateFactory = requireNonNull(stateFactory, "stateFactory is null");
        // Every verifier below formats stateInterface.getSimpleName() into its message eagerly,
        // so reject null up front with a descriptive message.
        requireNonNull(stateInterface, "stateInterface is null");

        verifyInputFunctionSignature(this.inputFunction, this.inputMetadata, stateInterface);
        verifyCombineFunction(this.combineFunction, stateInterface);
        verifyExactOutputFunction(this.outputFunction, stateInterface);
    }

    /**
     * Returns the SQL type produced by this aggregation.
     */
    public Type getOutputType()
    {
        return outputType;
    }

    /**
     * Returns an immutable list with one entry per input-function parameter, in parameter order.
     */
    public List<ParameterMetadata> getInputMetadata()
    {
        return inputMetadata;
    }

    /**
     * Returns the aggregation name.
     */
    public String getName()
    {
        return name;
    }

    /**
     * Returns the input function; its first parameter is the state.
     */
    public MethodHandle getInputFunction()
    {
        return inputFunction;
    }

    /**
     * Returns the combine function, whose parameters are two states.
     */
    public MethodHandle getCombineFunction()
    {
        return combineFunction;
    }

    /**
     * Returns the output function, whose parameters are a state and a {@link BlockBuilder}.
     */
    public MethodHandle getOutputFunction()
    {
        return outputFunction;
    }

    /**
     * Returns the state serializer supplied at construction.
     */
    public AccumulatorStateSerializer<?> getStateSerializer()
    {
        return stateSerializer;
    }

    /**
     * Returns the state factory supplied at construction.
     */
    public AccumulatorStateFactory<?> getStateFactory()
    {
        return stateFactory;
    }

    /**
     * Checks that the input function's parameters correspond one-to-one with {@code parameterMetadatas}.
     *
     * <p>The first parameter must be exactly {@code stateInterface} and be described as
     * {@code STATE}. Every later parameter's Java type must be compatible with its metadata entry.
     */
    private static void verifyInputFunctionSignature(MethodHandle method, List<ParameterMetadata> parameterMetadatas, Class<?> stateInterface)
    {
        Class<?>[] parameters = method.type().parameterArray();
        // Validate arity before indexing parameters[0] or parameterMetadatas.get(i). This makes
        // malformed input fail with IllegalArgumentException instead of an index-out-of-bounds error.
        checkArgument(parameters.length > 0, "Aggregation input function must have at least one parameter");
        checkArgument(parameters.length == parameterMetadatas.size(),
                "Aggregation input function %s has %s parameters, but %s parameter metadata entries were provided",
                method,
                parameters.length,
                parameterMetadatas.size());
        checkArgument(stateInterface == parameters[0], "First argument of aggregation input function must be %s", stateInterface.getSimpleName());
        checkArgument(parameterMetadatas.get(0).getParameterType() == STATE, "First parameter must be state");

        for (int i = 1; i < parameters.length; i++) {
            ParameterMetadata metadata = parameterMetadatas.get(i);
            switch (metadata.getParameterType()) {
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL:
                    checkArgument(parameters[i] == Block.class, "Parameter must be Block if it has @BlockPosition");
                    break;
                case INPUT_CHANNEL:
                    checkArgument(SUPPORTED_PARAMETER_TYPES.contains(parameters[i]), "Unsupported type: %s", parameters[i].getSimpleName());
                    checkArgument(parameters[i] == metadata.getSqlType().getJavaType(),
                            "Expected method %s parameter %s type to be %s (%s)",
                            method,
                            i,
                            metadata.getSqlType().getJavaType().getName(),
                            metadata.getSqlType());
                    break;
                case BLOCK_INDEX:
                    checkArgument(parameters[i] == int.class, "Block index parameter must be an int");
                    break;
                default:
                    // STATE is only permitted at position 0.
                    throw new IllegalArgumentException("Unsupported parameter: " + metadata.getParameterType());
            }
        }
    }

    /**
     * Checks that the combine function's parameter types are exactly {@code (stateInterface, stateInterface)}.
     */
    private static void verifyCombineFunction(MethodHandle method, Class<?> stateInterface)
    {
        Class<?>[] parameterTypes = method.type().parameterArray();
        checkArgument(parameterTypes.length == 2 && parameterTypes[0] == stateInterface && parameterTypes[1] == stateInterface, "Combine function must have the signature (%s, %s)", stateInterface.getSimpleName(), stateInterface.getSimpleName());
    }

    /**
     * Checks that the output function's parameter types are exactly {@code (stateInterface, BlockBuilder)}.
     * The constructor has already rejected a null {@code method}.
     */
    private static void verifyExactOutputFunction(MethodHandle method, Class<?> stateInterface)
    {
        Class<?>[] parameterTypes = method.type().parameterArray();
        checkArgument(parameterTypes.length == 2 && parameterTypes[0] == stateInterface && parameterTypes[1] == BlockBuilder.class, "Output function must have the signature (%s, BlockBuilder)", stateInterface.getSimpleName());
    }

    /**
     * Counts the entries in {@code metadatas} that consume an input channel.
     *
     * <p>Those are entries of type {@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL} or
     * {@code NULLABLE_BLOCK_INPUT_CHANNEL}. {@code STATE} and {@code BLOCK_INDEX} entries are not
     * counted.
     *
     * @param metadatas parameter metadata to inspect
     * @return the number of input-channel entries
     */
    public static int countInputChannels(List<ParameterMetadata> metadatas)
    {
        int parameters = 0;
        for (ParameterMetadata metadata : metadatas) {
            if (metadata.getParameterType() == INPUT_CHANNEL ||
                    metadata.getParameterType() == BLOCK_INPUT_CHANNEL ||
                    metadata.getParameterType() == NULLABLE_BLOCK_INPUT_CHANNEL) {
                parameters++;
            }
        }
        return parameters;
    }

    /**
     * Describes one parameter of an aggregation input function: its kind and, for input channels,
     * its SQL type.
     */
    public static class ParameterMetadata
    {
        private final ParameterType parameterType;
        private final Type sqlType;

        /**
         * Creates metadata for a parameter without a SQL type. Only valid for {@code BLOCK_INDEX}
         * and {@code STATE}.
         */
        public ParameterMetadata(ParameterType parameterType)
        {
            this(parameterType, null);
        }

        /**
         * Creates parameter metadata.
         *
         * @param parameterType kind of parameter
         * @param sqlType SQL type of the parameter; must be null exactly when
         * {@code parameterType} is {@code BLOCK_INDEX} or {@code STATE}
         * @throws NullPointerException if {@code parameterType} is null
         * @throws IllegalArgumentException if the presence of {@code sqlType} does not match {@code parameterType}
         */
        public ParameterMetadata(ParameterType parameterType, Type sqlType)
        {
            requireNonNull(parameterType, "parameterType is null");
            checkArgument((sqlType == null) == (parameterType == BLOCK_INDEX || parameterType == STATE),
                    "sqlType must be provided only for input channels");
            this.parameterType = parameterType;
            this.sqlType = sqlType;
        }

        /**
         * Creates input-channel metadata for a parameter of the given SQL type.
         *
         * @param sqlType SQL type of the parameter
         * @param isBlock whether the parameter is passed as a Block (annotated {@code @BlockPosition})
         * @param isNullable whether the parameter is annotated {@code @NullablePosition}; only
         * allowed together with {@code isBlock}
         * @param methodName name of the declaring method, used in the error message
         * @return metadata of type {@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL} or
         * {@code NULLABLE_BLOCK_INPUT_CHANNEL}
         * @throws IllegalArgumentException if {@code isNullable} is set without {@code isBlock}
         */
        public static ParameterMetadata fromSqlType(Type sqlType, boolean isBlock, boolean isNullable, String methodName)
        {
            if (!isBlock) {
                if (isNullable) {
                    throw new IllegalArgumentException(methodName + " contains a parameter with @NullablePosition that is not @BlockPosition");
                }
                return new ParameterMetadata(INPUT_CHANNEL, sqlType);
            }
            return new ParameterMetadata(isNullable ? NULLABLE_BLOCK_INPUT_CHANNEL : BLOCK_INPUT_CHANNEL, sqlType);
        }

        /**
         * Returns metadata for a block-position ({@code int}) parameter.
         */
        public static ParameterMetadata forBlockIndexParameter()
        {
            return new ParameterMetadata(BLOCK_INDEX);
        }

        /**
         * Returns metadata for the state parameter.
         */
        public static ParameterMetadata forStateParameter()
        {
            return new ParameterMetadata(STATE);
        }

        /**
         * Returns the parameter kind; never null.
         */
        public ParameterType getParameterType()
        {
            return parameterType;
        }

        /**
         * Returns the parameter's SQL type. Null exactly when the parameter kind is
         * {@code BLOCK_INDEX} or {@code STATE}.
         */
        public Type getSqlType()
        {
            return sqlType;
        }

        /**
         * Kinds of input-function parameters.
         */
        public enum ParameterType
        {
            INPUT_CHANNEL,
            BLOCK_INPUT_CHANNEL,
            NULLABLE_BLOCK_INPUT_CHANNEL,
            BLOCK_INDEX,
            STATE
        }
    }
}