/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.scalar; import com.facebook.presto.annotation.UsedByGeneratedCode; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.CompilerUtils; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.CallSiteBinder; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; import java.util.Collections; import java.util.List; import java.util.stream.IntStream; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.orderableTypeParameter; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; public abstract class AbstractGreatestLeast extends SqlScalarFunction { private static final MethodHandle CHECK_NOT_NAN = methodHandle(AbstractGreatestLeast.class, "checkNotNaN", String.class, double.class); private final OperatorType operatorType; protected AbstractGreatestLeast(String name, OperatorType operatorType) { super(new Signature( name, FunctionKind.SCALAR, ImmutableList.of(orderableTypeParameter("E")), ImmutableList.of(), parseTypeSignature("E"), ImmutableList.of(parseTypeSignature("E")), true)); this.operatorType = requireNonNull(operatorType, "operatorType is null"); } @Override public boolean isHidden() { return false; } @Override public boolean isDeterministic() { return true; } @Override public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { Type type = boundVariables.getTypeVariable("E"); checkArgument(type.isOrderable(), "Type must be orderable"); MethodHandle compareMethod = functionRegistry.getScalarFunctionImplementation(internalOperator(operatorType, BOOLEAN, ImmutableList.of(type, type))).getMethodHandle(); List<Class<?>> javaTypes = IntStream.range(0, arity) .mapToObj(i -> type.getJavaType()) .collect(toImmutableList()); Class<?> clazz = generate(javaTypes, type, compareMethod); MethodHandle methodHandle = methodHandle(clazz, getSignature().getName(), javaTypes.toArray(new Class<?>[javaTypes.size()])); List<Boolean> nullableParameters = ImmutableList.copyOf(Collections.nCopies(javaTypes.size(), false)); return new ScalarFunctionImplementation(false, nullableParameters, methodHandle, isDeterministic()); } @UsedByGeneratedCode public static void checkNotNaN(String name, double value) { if (Double.isNaN(value)) { throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Invalid argument to %s(): NaN", name)); } } private Class<?> generate(List<Class<?>> javaTypes, Type type, MethodHandle compareMethod) { String javaTypeName = javaTypes.stream() .map(Class::getSimpleName) .collect(joining()); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), CompilerUtils.makeClassName(javaTypeName + "$" + getSignature().getName()), type(Object.class)); definition.declareDefaultConstructor(a(PRIVATE)); List<Parameter> parameters = IntStream.range(0, javaTypes.size()) .mapToObj(i -> arg("arg" + i, javaTypes.get(i))) .collect(toImmutableList()); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), getSignature().getName(), type(javaTypes.get(0)), parameters); Scope scope = method.getScope(); BytecodeBlock body = method.getBody(); CallSiteBinder binder = new CallSiteBinder(); if (type.getTypeSignature().getBase().equals(StandardTypes.DOUBLE)) { for (Parameter parameter : parameters) { body.append(parameter); body.append(invoke(binder.bind(CHECK_NOT_NAN.bindTo(getSignature().getName())), "checkNotNaN")); } } Variable value = scope.declareVariable(javaTypes.get(0), "value"); body.append(value.set(parameters.get(0))); for (int i = 1; i < javaTypes.size(); i++) { body.append(new IfStatement() .condition(new BytecodeBlock() .append(parameters.get(i)) .append(value) .append(invoke(binder.bind(compareMethod), "compare"))) .ifTrue(value.set(parameters.get(i)))); } body.append(value.ret()); return defineClass(definition, Object.class, binder.getBindings(), new DynamicClassLoader(getClass().getClassLoader())); } }