/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.relational.optimizer; import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.ConstantExpression; import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.RowExpressionVisitor; import com.facebook.presto.sql.relational.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.List; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.constantNull; import static com.facebook.presto.sql.relational.Signatures.BIND; import static com.facebook.presto.sql.relational.Signatures.CAST; import static 
com.facebook.presto.sql.relational.Signatures.COALESCE;
import static com.facebook.presto.sql.relational.Signatures.DEREFERENCE;
import static com.facebook.presto.sql.relational.Signatures.IF;
import static com.facebook.presto.sql.relational.Signatures.IN;
import static com.facebook.presto.sql.relational.Signatures.IS_NULL;
import static com.facebook.presto.sql.relational.Signatures.NULL_IF;
import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR;
import static com.facebook.presto.sql.relational.Signatures.SWITCH;
import static com.facebook.presto.sql.relational.Signatures.TRY;
import static com.facebook.presto.sql.relational.Signatures.TRY_CAST;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.instanceOf;
import static com.google.common.collect.ImmutableList.toImmutableList;

/**
 * Simplifies {@link RowExpression} trees.
 *
 * <p>The optimizer walks the tree and applies the following rewrites:
 * <ul>
 * <li>an {@code IF} call whose condition optimizes to a constant is replaced by the optimized branch it selects;</li>
 * <li>a {@code TRY} call whose single argument is not a {@link CallExpression} is replaced by the optimized argument;</li>
 * <li>a call to a deterministic scalar function (or a {@code CAST}) whose optimized arguments are all constants
 * is evaluated immediately and replaced by a constant holding the result.</li>
 * </ul>
 * Other special forms are rebuilt with optimized arguments but are otherwise left as they are.
 */
public class ExpressionOptimizer
{
    private final FunctionRegistry registry;
    private final TypeManager typeManager;
    private final ConnectorSession session;

    /**
     * @param registry used to resolve scalar function implementations and cast coercions
     * @param typeManager used to resolve the return type of calls that are rebuilt
     * @param session converted to a {@link ConnectorSession}, which is bound to function
     * implementations whose first parameter is a {@code ConnectorSession}
     */
    public ExpressionOptimizer(FunctionRegistry registry, TypeManager typeManager, Session session)
    {
        this.registry = registry;
        this.typeManager = typeManager;
        this.session = session.toConnectorSession();
    }

    /**
     * Returns the optimized form of {@code expression}.
     *
     * @param expression the expression tree to optimize
     * @return the rewritten expression; parts that could not be simplified are rebuilt with optimized children
     */
    public RowExpression optimize(RowExpression expression)
    {
        return expression.accept(new Visitor(), null);
    }

    /**
     * Visitor performing the rewrite. It is an inner (non-static) class because it reads the
     * registry, type manager and session of the enclosing optimizer. The {@code Void} context is unused.
     */
    private class Visitor
            implements RowExpressionVisitor<Void, RowExpression>
    {
        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
        {
            return reference;
        }

        @Override
        public RowExpression visitConstant(ConstantExpression literal, Void context)
        {
            return literal;
        }

        /**
         * Optimizes a call: special forms are handled case by case, every other call (including
         * {@code CAST}) is constant-folded when all of its optimized arguments are constants and
         * the function is deterministic.
         */
        @Override
        public RowExpression visitCall(CallExpression call, Void context)
        {
            ScalarFunctionImplementation function;
            Signature signature = call.getSignature();

            if (signature.getName().equals(CAST)) {
                // A cast is implemented by the coercion function from the argument type to the call type.
                Signature functionSignature = registry.getCoercion(call.getArguments().get(0).getType(), call.getType());
                function = registry.getScalarFunctionImplementation(functionSignature);
            }
            else {
                switch (signature.getName()) {
                    // TODO: optimize these special forms
                    case IF: {
                        checkState(call.getArguments().size() == 3, "IF function should have 3 arguments. Get " + call.getArguments().size());
                        RowExpression optimizedOperand = call.getArguments().get(0).accept(this, context);
                        if (optimizedOperand instanceof ConstantExpression) {
                            ConstantExpression constantOperand = (ConstantExpression) optimizedOperand;
                            checkState(constantOperand.getType().equals(BOOLEAN), "Operand of IF function should be BOOLEAN type. Get type " + constantOperand.getType().getDisplayName());
                            if (Boolean.TRUE.equals(constantOperand.getValue())) {
                                return call.getArguments().get(1).accept(this, context);
                            }
                            // FALSE and NULL
                            return call.getArguments().get(2).accept(this, context);
                        }
                        // The condition is not constant, so the IF must be kept. Reuse the condition that
                        // was optimized above instead of visiting that subtree a second time; re-visiting
                        // it doubles the work at every level of IFs nested in condition position.
                        return call(signature, call.getType(), ImmutableList.of(
                                optimizedOperand,
                                call.getArguments().get(1).accept(this, context),
                                call.getArguments().get(2).accept(this, context)));
                    }
                    case TRY: {
                        checkState(call.getArguments().size() == 1, "try call expressions must have a single argument");
                        RowExpression argument = Iterables.getOnlyElement(call.getArguments());
                        if (!(argument instanceof CallExpression)) {
                            // The TRY wrapper is dropped when its argument is not a call.
                            return argument.accept(this, context);
                        }
                        return call(signature, call.getType(), optimizeArguments(call));
                    }
                    case BIND: {
                        checkState(call.getArguments().size() == 2, BIND + " function should have 2 arguments. Got " + call.getArguments().size());
                        RowExpression optimizedValue = call.getArguments().get(0).accept(this, context);
                        RowExpression optimizedFunction = call.getArguments().get(1).accept(this, context);
                        if (optimizedValue instanceof ConstantExpression && optimizedFunction instanceof ConstantExpression) {
                            // Here, optimizedValue and optimizedFunction should be merged together into a new ConstantExpression.
                            // It's not implemented because it would be dead code anyways because visitLambda does not produce ConstantExpression.
                            throw new UnsupportedOperationException();
                        }
                        return call(signature, call.getType(), ImmutableList.of(optimizedValue, optimizedFunction));
                    }
                    case NULL_IF:
                    case SWITCH:
                    case "WHEN":
                    case TRY_CAST:
                    case IS_NULL:
                    case COALESCE:
                    case "AND":
                    case "OR":
                    case IN:
                    case DEREFERENCE:
                    case ROW_CONSTRUCTOR: {
                        // These special forms are not folded; only their arguments are optimized.
                        return call(signature, call.getType(), optimizeArguments(call));
                    }
                    default:
                        function = registry.getScalarFunctionImplementation(signature);
                }
            }

            List<RowExpression> arguments = optimizeArguments(call);

            // TODO: optimize function calls with lambda arguments. For example, apply(x -> x + 2, 1)
            if (Iterables.all(arguments, instanceOf(ConstantExpression.class)) && function.isDeterministic()) {
                MethodHandle method = function.getMethodHandle();

                // Implementations may take the session as their first parameter; supply it up front.
                if (method.type().parameterCount() > 0 && method.type().parameterType(0) == ConnectorSession.class) {
                    method = method.bindTo(session);
                }

                int index = 0;
                List<Object> constantArguments = new ArrayList<>();
                for (RowExpression argument : arguments) {
                    Object value = ((ConstantExpression) argument).getValue();
                    // A null passed to a position the function does not declare nullable makes the
                    // whole call null, so the function is not invoked.
                    if (value == null && !function.getNullableArguments().get(index)) {
                        return constantNull(call.getType());
                    }
                    constantArguments.add(value);
                    index++;
                }

                try {
                    return constant(method.invokeWithArguments(constantArguments), call.getType());
                }
                catch (Throwable e) {
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                    // Do nothing. As a result, this specific tree will be left untouched. But irrelevant expressions will continue to get evaluated and optimized.
                }
            }

            return call(signature, typeManager.getType(signature.getReturnType()), arguments);
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            // Only the body is optimized; argument types and names are carried over unchanged.
            return new LambdaDefinitionExpression(lambda.getArgumentTypes(), lambda.getArguments(), lambda.getBody().accept(this, context));
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return reference;
        }

        /**
         * Optimizes every argument of {@code call}, in order, and returns the results as an immutable list.
         */
        private List<RowExpression> optimizeArguments(CallExpression call)
        {
            return call.getArguments().stream()
                    .map(argument -> argument.accept(this, null))
                    .collect(toImmutableList());
        }
    }
}