/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Function; import java.util.function.Predicate; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; import static com.facebook.presto.sql.ExpressionUtils.expressionOrNullSymbols; import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; import static com.facebook.presto.sql.ExpressionUtils.stripNonDeterministicConjuncts; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; /** * Computes the effective predicate at the top of the specified PlanNode * <p> * Note: non-deterministic predicates can not be pulled up (so they will be ignored) */ public class EffectivePredicateExtractor extends PlanVisitor<Void, Expression> { public static Expression extract(PlanNode node, Map<Symbol, Type> symbolTypes) { return node.accept(new EffectivePredicateExtractor(symbolTypes), null); } private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION = entry -> entry.getValue().equals(entry.getKey().toSymbolReference()); private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY = entry -> { SymbolReference reference = entry.getKey().toSymbolReference(); Expression expression = entry.getValue(); // TODO: switch this to 'IS NOT DISTINCT FROM' syntax when EqualityInference properly supports it return new ComparisonExpression(ComparisonExpressionType.EQUAL, reference, expression); }; private final Map<Symbol, Type> symbolTypes; public EffectivePredicateExtractor(Map<Symbol, Type> symbolTypes) { this.symbolTypes = symbolTypes; } @Override protected Expression visitPlan(PlanNode node, Void context) { return TRUE_LITERAL; } @Override public Expression visitAggregation(AggregationNode node, Void context) { // GROUP BY () always produces a group, regardless of whether there's any // input (unlike the case where there are group by keys, which produce // no output if there's no input). // Therefore, we can't say anything about the effective predicate of the // output of such an aggregation. if (node.getGroupingKeys().isEmpty()) { return TRUE_LITERAL; } Expression underlyingPredicate = node.getSource().accept(this, context); return pullExpressionThroughSymbols(underlyingPredicate, node.getGroupingKeys()); } @Override public Expression visitFilter(FilterNode node, Void context) { Expression underlyingPredicate = node.getSource().accept(this, context); Expression predicate = node.getPredicate(); // Remove non-deterministic conjuncts predicate = stripNonDeterministicConjuncts(predicate); return combineConjuncts(predicate, underlyingPredicate); } @Override public Expression visitExchange(ExchangeNode node, Void context) { return deriveCommonPredicates(node, source -> { Map<Symbol, SymbolReference> mappings = new HashMap<>(); for (int i = 0; i < node.getInputs().get(source).size(); i++) { mappings.put( node.getOutputSymbols().get(i), node.getInputs().get(source).get(i).toSymbolReference()); } return mappings.entrySet(); }); } @Override public Expression visitProject(ProjectNode node, Void context) { // TODO: add simple algebraic solver for projection translation (right now only considers identity projections) Expression underlyingPredicate = node.getSource().accept(this, context); List<Expression> projectionEqualities = node.getAssignments().entrySet().stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); return pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(projectionEqualities) .add(underlyingPredicate) .build()), node.getOutputSymbols()); } @Override public Expression visitTopN(TopNNode node, Void context) { return node.getSource().accept(this, context); } @Override public Expression visitLimit(LimitNode node, Void context) { return node.getSource().accept(this, context); } @Override public Expression visitDistinctLimit(DistinctLimitNode node, Void context) { return node.getSource().accept(this, context); } @Override public Expression visitTableScan(TableScanNode node, Void context) { Map<ColumnHandle, Symbol> assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); return DomainTranslator.toPredicate(spanTupleDomain(node.getCurrentConstraint()).transform(assignments::get)); } private static TupleDomain<ColumnHandle> spanTupleDomain(TupleDomain<ColumnHandle> tupleDomain) { if (tupleDomain.isNone()) { return tupleDomain; } // Simplify domains if they get too complex Map<ColumnHandle, Domain> spannedDomains = Maps.transformValues(tupleDomain.getDomains().get(), DomainUtils::simplifyDomain); return TupleDomain.withColumnDomains(spannedDomains); } @Override public Expression visitSort(SortNode node, Void context) { return node.getSource().accept(this, context); } @Override public Expression visitWindow(WindowNode node, Void context) { return node.getSource().accept(this, context); } @Override public Expression visitUnion(UnionNode node, Void context) { return deriveCommonPredicates(node, source -> node.outputSymbolMap(source).entries()); } @Override public Expression visitJoin(JoinNode node, Void context) { Expression leftPredicate = node.getLeft().accept(this, context); Expression rightPredicate = node.getRight().accept(this, context); List<Expression> joinConjuncts = new ArrayList<>(); for (JoinNode.EquiJoinClause clause : node.getCriteria()) { joinConjuncts.add(new ComparisonExpression(ComparisonExpressionType.EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference())); } switch (node.getType()) { case INNER: return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) .addAll(pullExpressionsThroughSymbols(joinConjuncts, node.getOutputSymbols())) .build()); case LEFT: return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .build()); case RIGHT: return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .build()); case FULL: return combineConjuncts(ImmutableList.<Expression>builder() .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains, node.getRight().getOutputSymbols()::contains)) .build()); default: throw new UnsupportedOperationException("Unknown join type: " + node.getType()); } } private static Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, Collection<Symbol> outputSymbols, Predicate<Symbol>... nullSymbolScopes) { // Conjuncts without any symbol dependencies cannot be applied to the effective predicate (e.g. FALSE literal) return conjuncts.stream() .map(expression -> pullExpressionThroughSymbols(expression, outputSymbols)) .map(expression -> DependencyExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) .map(expressionOrNullSymbols(nullSymbolScopes)) .collect(toImmutableList()); } @Override public Expression visitSemiJoin(SemiJoinNode node, Void context) { // Filtering source does not change the effective predicate over the output symbols return node.getSource().accept(this, context); } private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) { // Find the predicates that can be pulled up from each source List<Set<Expression>> sourceOutputConjuncts = new ArrayList<>(); for (int i = 0; i < node.getSources().size(); i++) { Expression underlyingPredicate = node.getSources().get(i).accept(this, null); List<Expression> equalities = mapping.apply(i).stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(equalities) .add(underlyingPredicate) .build()), node.getOutputSymbols())))); } // Find the intersection of predicates across all sources // TODO: use a more precise way to determine overlapping conjuncts (e.g. commutative predicates) Iterator<Set<Expression>> iterator = sourceOutputConjuncts.iterator(); Set<Expression> potentialOutputConjuncts = iterator.next(); while (iterator.hasNext()) { potentialOutputConjuncts = Sets.intersection(potentialOutputConjuncts, iterator.next()); } return combineConjuncts(potentialOutputConjuncts); } private static List<Expression> pullExpressionsThroughSymbols(List<Expression> expressions, Collection<Symbol> symbols) { return expressions.stream() .map(expression -> pullExpressionThroughSymbols(expression, symbols)) .collect(toImmutableList()); } private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) { EqualityInference equalityInference = createEqualityInference(expression); ImmutableList.Builder<Expression> effectiveConjuncts = ImmutableList.builder(); for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) { if (DeterminismEvaluator.isDeterministic(conjunct)) { Expression rewritten = equalityInference.rewriteExpression(conjunct, in(symbols)); if (rewritten != null) { effectiveConjuncts.add(rewritten); } } } effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities()); return combineConjuncts(effectiveConjuncts.build()); } }