/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.NoOpSymbolResolver;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.facebook.presto.sql.ExpressionUtils.combinePredicates;
import static com.facebook.presto.sql.ExpressionUtils.extractPredicates;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpressionType.IS_DISTINCT_FROM;
import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Type.OR;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
public class SimplifyExpressions
implements PlanOptimizer
{
private final Metadata metadata;
private final SqlParser sqlParser;
/**
 * Creates the optimizer.
 *
 * @param metadata kept and later passed to the plan rewriter created in {@code optimize}; must not be null
 * @param sqlParser kept and later passed to the plan rewriter created in {@code optimize}; must not be null
 * @throws NullPointerException if either argument is null
 */
public SimplifyExpressions(Metadata metadata, SqlParser sqlParser)
{
    // Validate both arguments (metadata first) before assigning any field.
    requireNonNull(metadata, "metadata is null");
    requireNonNull(sqlParser, "sqlParser is null");
    this.metadata = metadata;
    this.sqlParser = sqlParser;
}
/**
 * Rewrites {@code plan} with a fresh {@code Rewriter}, which simplifies the expressions held by
 * project, filter and table scan nodes.
 *
 * @param plan root of the plan to rewrite
 * @param session session passed to expression type analysis and to the expression interpreter
 * @param types symbol types used when analyzing expressions
 * @param symbolAllocator only checked for null here; not otherwise used by this optimizer
 * @param idAllocator source of ids for plan nodes created during rewriting
 * @return the rewritten plan
 * @throws NullPointerException if any argument is null
 */
@Override
public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
{
    requireNonNull(plan, "plan is null");
    requireNonNull(session, "session is null");
    requireNonNull(types, "types is null");
    requireNonNull(symbolAllocator, "symbolAllocator is null");
    requireNonNull(idAllocator, "idAllocator is null");

    Rewriter rewriter = new Rewriter(metadata, sqlParser, session, types, idAllocator);
    return SimplePlanRewriter.rewriteWith(rewriter, plan);
}
private static class Rewriter
extends SimplePlanRewriter<Void>
{
private final Metadata metadata;
private final SqlParser sqlParser;
private final Session session;
private final Map<Symbol, Type> types;
private final PlanNodeIdAllocator idAllocator;
/**
 * Captures the collaborators used to simplify expressions while rewriting a plan.
 *
 * @param metadata passed to expression type analysis and to the expression interpreter
 * @param sqlParser passed to expression type analysis
 * @param session passed to expression type analysis and to the expression interpreter
 * @param types symbol types used when analyzing expressions
 * @param idAllocator source of ids for plan nodes created by this rewriter
 * @throws NullPointerException if any argument is null
 */
public Rewriter(Metadata metadata, SqlParser sqlParser, Session session, Map<Symbol, Type> types, PlanNodeIdAllocator idAllocator)
{
    // Fail fast with a descriptive message instead of a late NPE deep inside a rewrite;
    // this follows the null-check convention already used by the enclosing class.
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
    this.session = requireNonNull(session, "session is null");
    this.types = requireNonNull(types, "types is null");
    this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
}
/**
 * Rewrites the source of the projection, then rebuilds the node (same id) with every assignment
 * expression passed through {@code simplifyExpression}.
 */
@Override
public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context)
{
    // The source is rewritten first so the new node sits on top of the already-rewritten subtree.
    PlanNode rewrittenSource = context.rewrite(node.getSource());
    return new ProjectNode(
            node.getId(),
            rewrittenSource,
            node.getAssignments().rewrite(this::simplifyExpression));
}
/**
 * Simplifies the filter predicate and drops or replaces the filter when the result is a constant.
 * <ul>
 * <li>Predicate equals {@code TRUE_LITERAL}: the rewritten source is returned without a filter.</li>
 * <li>Predicate equals {@code FALSE_LITERAL} or is a {@link NullLiteral}: a {@link ValuesNode} is returned,
 * built from a newly allocated id, the filter's output symbols and an empty list.</li>
 * <li>Otherwise: a new {@link FilterNode} with the original id, the rewritten source and the simplified predicate.</li>
 * </ul>
 */
@Override
public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
{
    PlanNode rewrittenSource = context.rewrite(node.getSource());
    Expression predicate = simplifyExpression(node.getPredicate());

    if (predicate.equals(TRUE_LITERAL)) {
        // The filter keeps everything, so it is redundant.
        return rewrittenSource;
    }

    // TODO: this needs to check whether the boolean expression coerces to false in a more general way.
    // E.g., simplify() not always produces a literal when the expression is constant
    boolean constantFalseOrNull = predicate.equals(FALSE_LITERAL) || predicate instanceof NullLiteral;
    if (constantFalseOrNull) {
        return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of());
    }

    return new FilterNode(node.getId(), rewrittenSource, predicate);
}
/**
 * Rebuilds the table scan with its original constraint simplified.
 *
 * A {@code null} original constraint stays {@code null}. Every other attribute (id, table, output
 * symbols, assignments, layout and current constraint) is copied from {@code node} unchanged.
 */
@Override
public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
{
    Expression constraint = node.getOriginalConstraint();
    Expression simplifiedConstraint = (constraint == null) ? null : simplifyExpression(constraint);

    return new TableScanNode(
            node.getId(),
            node.getTable(),
            node.getOutputSymbols(),
            node.getAssignments(),
            node.getLayout(),
            node.getCurrentConstraint(),
            simplifiedConstraint);
}
/**
 * Simplifies a single expression.
 *
 * A bare {@link SymbolReference} is returned as is. Anything else goes through these steps in order:
 * <ol>
 * <li>rewrite with {@code PushDownNegationsExpressionRewriter};</li>
 * <li>rewrite with {@code ExtractCommonPredicatesExpressionRewriter}, starting in the {@code ROOT_NODE} context;</li>
 * <li>compute the expression types of the rewritten expression;</li>
 * <li>optimize it with an {@link ExpressionInterpreter} using {@code NoOpSymbolResolver.INSTANCE};</li>
 * <li>convert the optimizer result back to an {@link Expression}, using the type recorded for the
 * rewritten root expression.</li>
 * </ol>
 */
private Expression simplifyExpression(Expression expression)
{
    if (expression instanceof SymbolReference) {
        return expression;
    }

    Expression negationsPushedDown = ExpressionTreeRewriter.rewriteWith(new PushDownNegationsExpressionRewriter(), expression);
    Expression normalized = ExpressionTreeRewriter.rewriteWith(
            new ExtractCommonPredicatesExpressionRewriter(),
            negationsPushedDown,
            NodeContext.ROOT_NODE);

    IdentityLinkedHashMap<Expression, Type> typesByExpression = getExpressionTypes(
            session,
            metadata,
            sqlParser,
            types,
            normalized,
            emptyList() /* parameters already replaced */);

    ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(normalized, metadata, session, typesByExpression);
    return LiteralInterpreter.toExpression(
            optimizer.optimize(NoOpSymbolResolver.INSTANCE),
            typesByExpression.get(normalized));
}
}
private static class PushDownNegationsExpressionRewriter
extends ExpressionRewriter<Void>
{
@Override
public Expression rewriteNotExpression(NotExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
if (node.getValue() instanceof LogicalBinaryExpression) {
LogicalBinaryExpression child = (LogicalBinaryExpression) node.getValue();
List<Expression> predicates = extractPredicates(child);
List<Expression> negatedPredicates = predicates.stream()
.map(predicate -> treeRewriter.rewrite((Expression) new NotExpression(predicate), context))
.collect(toImmutableList());
return combinePredicates(child.getType().flip(), negatedPredicates);
}
else if (node.getValue() instanceof ComparisonExpression && ((ComparisonExpression) node.getValue()).getType() != IS_DISTINCT_FROM) {
ComparisonExpression child = (ComparisonExpression) node.getValue();
return new ComparisonExpression(
child.getType().negate(),
treeRewriter.rewrite(child.getLeft(), context),
treeRewriter.rewrite(child.getRight(), context));
}
else if (node.getValue() instanceof NotExpression) {
NotExpression child = (NotExpression) node.getValue();
return treeRewriter.rewrite(child.getValue(), context);
}
return new NotExpression(treeRewriter.rewrite(node.getValue(), context));
}
}
/**
 * Rewrite context for {@code ExtractCommonPredicatesExpressionRewriter}: records whether the
 * expression being rewritten is the one the rewrite started from.
 */
private enum NodeContext
{
    /** Context supplied when the rewrite of an expression starts. */
    ROOT_NODE,
    /** Context supplied by the rewriter when it recurses into an expression. */
    NOT_ROOT_NODE;

    /**
     * @return {@code true} only for {@link #ROOT_NODE}
     */
    boolean isRootNode()
    {
        return ROOT_NODE.equals(this);
    }
}
private static class ExtractCommonPredicatesExpressionRewriter
extends ExpressionRewriter<NodeContext>
{
/**
 * Catch-all hook for expressions without a more specific override in this rewriter.
 *
 * In the root context the same node is rewritten again with {@code NOT_ROOT_NODE}, so nothing
 * reached from it is treated as the root. In any other context {@code null} is returned.
 * NOTE(review): {@code null} presumably tells the tree rewriter to apply its default handling;
 * confirm against {@code ExpressionTreeRewriter}.
 */
@Override
public Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter)
{
    if (!context.isRootNode()) {
        return null;
    }
    return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
}
/**
 * Rewrites an AND/OR expression.
 *
 * The predicates extracted for the node's own operator are each rewritten in the {@code NOT_ROOT_NODE}
 * context and recombined with that operator. If the result is still a logical binary expression,
 * its common predicates are extracted. When this node was rewritten in the root context and the
 * outcome is an OR, distribution is attempted so that an AND ends up at the root where possible.
 */
@Override
public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter)
{
    LogicalBinaryExpression.Type operator = node.getType();

    ImmutableList.Builder<Expression> rewrittenOperands = ImmutableList.builder();
    for (Expression operand : extractPredicates(operator, node)) {
        rewrittenOperands.add(treeRewriter.rewrite(operand, NodeContext.NOT_ROOT_NODE));
    }
    Expression recombined = combinePredicates(operator, rewrittenOperands.build());

    if (!(recombined instanceof LogicalBinaryExpression)) {
        return recombined;
    }

    Expression factored = extractCommonPredicates((LogicalBinaryExpression) recombined);

    // Prefer AND LogicalBinaryExpression at the root if possible
    boolean orAtRoot = context.isRootNode()
            && factored instanceof LogicalBinaryExpression
            && ((LogicalBinaryExpression) factored).getType() == OR;
    return orAtRoot ? distributeIfPossible((LogicalBinaryExpression) factored) : factored;
}
/**
 * Factors out the deterministic predicates shared by every operand group of {@code node}.
 *
 * With "outer" being the operator of {@code node} and "inner" its flipped operator, the result is
 * the inner-combination of all common predicates plus one residual term. The residual is the
 * outer-combination of each group's remaining predicates, each group being inner-combined.
 * Only predicates accepted by {@code filterDeterministicPredicates} are candidates for extraction.
 */
private static Expression extractCommonPredicates(LogicalBinaryExpression node)
{
    List<List<Expression>> groups = getSubPredicates(node);

    // Fold the deterministic predicates of every group into their running intersection.
    Set<Expression> shared = null;
    for (List<Expression> group : groups) {
        Set<Expression> deterministic = filterDeterministicPredicates(group);
        shared = (shared == null) ? deterministic : Sets.intersection(shared, deterministic);
    }
    if (shared == null) {
        shared = emptySet();
    }
    // Snapshot the (possibly lazy) intersection view.
    Set<Expression> commonPredicates = ImmutableSet.copyOf(shared);

    LogicalBinaryExpression.Type outerType = node.getType();
    LogicalBinaryExpression.Type innerType = outerType.flip();

    // What is left of each group once the common predicates are taken out.
    ImmutableList.Builder<Expression> residualTerms = ImmutableList.builder();
    for (List<Expression> group : groups) {
        residualTerms.add(combinePredicates(innerType, removeAll(group, commonPredicates)));
    }
    Expression residual = combinePredicates(outerType, residualTerms.build());

    ImmutableList.Builder<Expression> factored = ImmutableList.builder();
    factored.addAll(commonPredicates);
    factored.add(residual);
    return combinePredicates(innerType, factored.build());
}
/**
 * Splits {@code expression} into operand groups.
 *
 * The expression is first split into the predicates extracted for its own operator. Each of those
 * that is itself a logical binary expression is split again into its own predicates; any other
 * operand becomes a single-element group.
 */
private static List<List<Expression>> getSubPredicates(LogicalBinaryExpression expression)
{
    ImmutableList.Builder<List<Expression>> groups = ImmutableList.builder();
    for (Expression operand : extractPredicates(expression.getType(), expression)) {
        if (operand instanceof LogicalBinaryExpression) {
            groups.add(extractPredicates((LogicalBinaryExpression) operand));
        }
        else {
            groups.add(ImmutableList.of(operand));
        }
    }
    return groups.build();
}
/**
 * Applies the boolean distributive property.
 *
 * For example:
 * ( A & B ) | ( C & D ) => ( A | C ) & ( A | D ) & ( B | C ) & ( B | D )
 *
 * Returns the original expression unchanged if it is non-deterministic, or if distributing would
 * expand it by too much. The size after distribution is estimated as the product of the operand
 * group sizes multiplied by the number of groups; distribution is skipped when that estimate
 * overflows an int or exceeds twice the current number of base expressions.
 */
private static Expression distributeIfPossible(LogicalBinaryExpression expression)
{
    if (!DeterminismEvaluator.isDeterministic(expression)) {
        // Do not distribute boolean expressions if there are any non-deterministic elements
        // TODO: This can be optimized further if non-deterministic elements are not repeated
        return expression;
    }

    // Each operand group as a set, so repeated predicates inside one group count once.
    List<Set<Expression>> operandSets = getSubPredicates(expression).stream()
            .map(ImmutableSet::copyOf)
            .collect(toList());

    int currentTermCount = operandSets.stream()
            .mapToInt(Set::size)
            .sum();

    int distributedTermCount;
    try {
        int combinations = operandSets.stream()
                .mapToInt(Set::size)
                .reduce(Math::multiplyExact)
                .getAsInt();
        distributedTermCount = Math.multiplyExact(combinations, operandSets.size());
    }
    catch (ArithmeticException overflow) {
        // Integer overflow from multiplication means there are too many expressions
        return expression;
    }

    // Do not distribute boolean expressions if it would create 2x more base expressions
    // (e.g. A, B, C, D from the above example). This is just an arbitrary heuristic to
    // avoid cross product expression explosion.
    boolean expandsTooMuch = distributedTermCount > currentTermCount * 2;
    if (expandsTooMuch) {
        return expression;
    }

    LogicalBinaryExpression.Type originalType = expression.getType();
    List<Expression> distributedTerms = Sets.cartesianProduct(operandSets).stream()
            .map(combination -> combinePredicates(originalType, combination))
            .collect(toImmutableList());
    return combinePredicates(originalType.flip(), distributedTerms);
}
/**
 * Returns the predicates from {@code predicates} that {@link DeterminismEvaluator#isDeterministic}
 * accepts, collected into a {@link Set} (so equal predicates appear once).
 *
 * NOTE(review): the set is produced by {@code Collectors.toSet()}, so its concrete type and
 * iteration order are not chosen here. {@code extractCommonPredicates} intersects these sets and
 * copies the result, so the order of the extracted common predicates follows this set's order.
 *
 * @param predicates candidate predicates
 * @return the deterministic subset of {@code predicates}
 */
private static Set<Expression> filterDeterministicPredicates(List<Expression> predicates)
{
// Non-deterministic predicates are dropped so they are never treated as common predicates.
return predicates.stream()
.filter(DeterminismEvaluator::isDeterministic)
.collect(toSet());
}
/**
 * Returns the elements of {@code collection}, in iteration order, that are not contained in
 * {@code elementsToRemove}. Neither argument is modified; the result is an immutable list.
 *
 * @param collection elements to filter
 * @param elementsToRemove elements to leave out, tested with {@code contains}
 * @return a new immutable list of the retained elements
 */
private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove)
{
    ImmutableList.Builder<T> retained = ImmutableList.builder();
    for (T element : collection) {
        if (!elementsToRemove.contains(element)) {
            retained.add(element);
        }
    }
    return retained.build();
}
}
}