/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.aggregation; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.SqlAggregationFunction; import com.facebook.presto.operator.aggregation.state.StateCompiler; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateSerializer; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.lang.invoke.MethodHandle; import java.util.List; import java.util.Map; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.and; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.not; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.or; import static com.facebook.presto.metadata.Signature.orderableTypeParameter; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; import static com.facebook.presto.operator.aggregation.state.TwoNullableValueStateMapping.getStateClass; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.collect.ImmutableList.toImmutableList; public abstract class AbstractMinMaxBy extends SqlAggregationFunction { private final boolean min; protected AbstractMinMaxBy(boolean min) { super((min ? "min" : "max") + "_by", ImmutableList.of(orderableTypeParameter("K"), typeVariable("V")), ImmutableList.of(), parseTypeSignature("V"), ImmutableList.of(parseTypeSignature("V"), parseTypeSignature("K"))); this.min = min; } @Override public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); return generateAggregation(valueType, keyType, functionRegistry); } private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, FunctionRegistry functionRegistry) { Class<?> stateClazz = getStateClass(keyType.getJavaType(), valueType.getJavaType()); Map<String, Type> stateFieldTypes = ImmutableMap.of("First", keyType, "Second", valueType); DynamicClassLoader classLoader = new DynamicClassLoader(getClass().getClassLoader()); AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClazz, stateFieldTypes, classLoader); AccumulatorStateSerializer<?> stateSerializer = StateCompiler.generateStateSerializer(stateClazz, stateFieldTypes, classLoader); Type intermediateType = stateSerializer.getSerializedType(); List<Type> inputTypes = ImmutableList.of(valueType, keyType); CallSiteBinder binder = new CallSiteBinder(); OperatorType operator = min ? LESS_THAN : GREATER_THAN; MethodHandle compareMethod = functionRegistry.getScalarFunctionImplementation(functionRegistry.resolveOperator(operator, ImmutableList.of(keyType, keyType))).getMethodHandle(); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName("processMaxOrMinBy"), type(Object.class)); definition.declareDefaultConstructor(a(PRIVATE)); generateInputMethod(definition, binder, compareMethod, keyType, valueType, stateClazz); generateCombineMethod(definition, binder, compareMethod, keyType, valueType, stateClazz); generateOutputMethod(definition, binder, valueType, stateClazz); Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), classLoader); MethodHandle inputMethod = methodHandle(generatedClass, "input", stateClazz, Block.class, Block.class, int.class); MethodHandle combineMethod = methodHandle(generatedClass, "combine", stateClazz, stateClazz); MethodHandle outputMethod = methodHandle(generatedClass, "output", stateClazz, BlockBuilder.class); AggregationMetadata metadata = new AggregationMetadata( generateAggregationName(getSignature().getName(), valueType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(valueType, keyType), inputMethod, combineMethod, outputMethod, stateClazz, stateSerializer, stateFactory, valueType); GenericAccumulatorFactoryBinder factory = AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, classLoader); return new InternalAggregationFunction(getSignature().getName(), inputTypes, intermediateType, valueType, true, factory); } private static List<ParameterMetadata> createInputParameterMetadata(Type value, Type key) { return ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(NULLABLE_BLOCK_INPUT_CHANNEL, value), new ParameterMetadata(BLOCK_INPUT_CHANNEL, key), new ParameterMetadata(BLOCK_INDEX)); } private void generateInputMethod(ClassDefinition definition, CallSiteBinder binder, MethodHandle compareMethod, Type keyType, Type valueType, Class<?> stateClass) { Parameter state = arg("state", stateClass); Parameter value = arg("value", Block.class); Parameter key = arg("key", Block.class); Parameter position = arg("position", int.class); MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "input", type(void.class), state, value, key, position); if (keyType.equals(UNKNOWN)) { method.getBody().ret(); return; } SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType); BytecodeBlock ifBlock = new BytecodeBlock() .append(state.invoke("setFirst", void.class, keySqlType.getValue(key, position))) .append(state.invoke("setFirstNull", void.class, constantBoolean(false))) .append(state.invoke("setSecondNull", void.class, value.invoke("isNull", boolean.class, position))); if (!valueType.equals(UNKNOWN)) { SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType); ifBlock.append(new IfStatement() .condition(value.invoke("isNull", boolean.class, position)) .ifFalse(state.invoke("setSecond", void.class, valueSqlType.getValue(value, position)))); } method.getBody().append(new IfStatement() .condition(or( state.invoke("isFirstNull", boolean.class), and( not(key.invoke("isNull", boolean.class, position)), loadConstant(binder, compareMethod, MethodHandle.class).invoke("invokeExact", boolean.class, keySqlType.getValue(key, position), state.invoke("getFirst", keyType.getJavaType()))))) .ifTrue(ifBlock)) .ret(); } private void generateCombineMethod(ClassDefinition definition, CallSiteBinder binder, MethodHandle compareMethod, Type keyType, Type valueType, Class<?> stateClass) { Parameter state = arg("state", stateClass); Parameter otherState = arg("otherState", stateClass); MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "combine", type(void.class), state, otherState); if (keyType.equals(UNKNOWN)) { method.getBody().ret(); return; } Class<?> keyJavaType = keyType.getJavaType(); BytecodeBlock ifBlock = new BytecodeBlock() .append(state.invoke("setFirst", void.class, otherState.invoke("getFirst", keyJavaType))) .append(state.invoke("setFirstNull", void.class, otherState.invoke("isFirstNull", boolean.class))) .append(state.invoke("setSecondNull", void.class, otherState.invoke("isSecondNull", boolean.class))); if (!valueType.equals(UNKNOWN)) { ifBlock.append(state.invoke("setSecond", void.class, otherState.invoke("getSecond", valueType.getJavaType()))); } method.getBody() .append(new IfStatement() .condition(or( state.invoke("isFirstNull", boolean.class), and( not(otherState.invoke("isFirstNull", boolean.class)), loadConstant(binder, compareMethod, MethodHandle.class).invoke("invokeExact", boolean.class, otherState.invoke("getFirst", keyJavaType), state.invoke("getFirst", keyJavaType))))) .ifTrue(ifBlock)) .ret(); } private void generateOutputMethod(ClassDefinition definition, CallSiteBinder binder, Type valueType, Class<?> stateClass) { Parameter state = arg("state", stateClass); Parameter out = arg("out", BlockBuilder.class); MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "output", type(void.class), state, out); IfStatement ifStatement = new IfStatement() .condition(or(state.invoke("isFirstNull", boolean.class), state.invoke("isSecondNull", boolean.class))) .ifTrue(new BytecodeBlock().append(out.invoke("appendNull", BlockBuilder.class)).pop()); if (!valueType.equals(UNKNOWN)) { ifStatement.ifFalse(constantType(binder, valueType).writeValue(out, state.invoke("getSecond", valueType.getJavaType()))); } method.getBody().append(ifStatement).ret(); } }