/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.EffectivePredicateExtractor; import com.facebook.presto.sql.planner.EqualityInference; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.LiteralInterpreter; import com.facebook.presto.sql.planner.NoOpSymbolResolver; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; import 
com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.airlift.log.Logger; import java.util.ArrayList; import java.util.Collection; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; import static com.facebook.presto.sql.ExpressionUtils.expressionOrNullSymbols; import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; import static com.facebook.presto.sql.ExpressionUtils.stripNonDeterministicConjuncts; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.DeterminismEvaluator.isDeterministic; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static 
com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.equalTo; import static com.google.common.base.Predicates.in; import static com.google.common.base.Predicates.not; import static com.google.common.collect.Iterables.filter; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class PredicatePushDown implements PlanOptimizer { private static final Logger log = Logger.get(PredicatePushDown.class); private final Metadata metadata; private final SqlParser sqlParser; public PredicatePushDown(Metadata metadata, SqlParser sqlParser) { this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); } @Override public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator, metadata, sqlParser, session, types), plan, BooleanLiteral.TRUE_LITERAL); } private static class Rewriter extends SimplePlanRewriter<Expression> { private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; private final SqlParser sqlParser; private final Session session; private final Map<Symbol, Type> types; private final ExpressionEquivalence expressionEquivalence; private Rewriter( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, 
SqlParser sqlParser, Session session, Map<Symbol, Type> types) { this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.expressionEquivalence = new ExpressionEquivalence(metadata, sqlParser); } @Override public PlanNode visitPlan(PlanNode node, RewriteContext<Expression> context) { PlanNode rewrittenNode = context.defaultRewrite(node, BooleanLiteral.TRUE_LITERAL); if (!context.get().equals(BooleanLiteral.TRUE_LITERAL)) { // Drop in a FilterNode b/c we cannot push our predicate down any further rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, context.get()); } return rewrittenNode; } @Override public PlanNode visitExchange(ExchangeNode node, RewriteContext<Expression> context) { boolean modified = false; ImmutableList.Builder<PlanNode> builder = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { Map<Symbol, SymbolReference> outputsToInputs = new HashMap<>(); for (int index = 0; index < node.getInputs().get(i).size(); index++) { outputsToInputs.put( node.getOutputSymbols().get(index), node.getInputs().get(i).get(index).toSymbolReference()); } Expression sourcePredicate = new ExpressionSymbolInliner(outputsToInputs).rewrite(context.get()); PlanNode source = node.getSources().get(i); PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); if (rewrittenSource != source) { modified = true; } builder.add(rewrittenSource); } if (modified) { return new ExchangeNode( node.getId(), node.getType(), node.getScope(), node.getPartitioningScheme(), builder.build(), node.getInputs()); } return node; } @Override public PlanNode visitProject(ProjectNode node, RewriteContext<Expression> 
context) { Set<Symbol> deterministicSymbols = node.getAssignments().entrySet().stream() .filter(entry -> DeterminismEvaluator.isDeterministic(entry.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toSet()); Predicate<Expression> deterministic = conjunct -> DependencyExtractor.extractUnique(conjunct).stream() .allMatch(deterministicSymbols::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); // Push down conjuncts from the inherited predicate that don't depend on non-deterministic assignments PlanNode rewrittenNode = context.defaultRewrite(node, new ExpressionSymbolInliner(node.getAssignments().getMap()).rewrite(combineConjuncts(conjuncts.get(true)))); // All non-deterministic conjuncts, if any, will be in the filter node. if (!conjuncts.get(false).isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; } @Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Expression> context) { checkState(!DependencyExtractor.extractUnique(context.get()).contains(node.getGroupIdSymbol()), "groupId symbol cannot be referenced in predicate"); Map<Symbol, SymbolReference> commonGroupingSymbolMapping = node.getGroupingSetMappings().entrySet().stream() .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); Predicate<Expression> pushdownEligiblePredicate = conjunct -> DependencyExtractor.extractUnique(conjunct).stream() .allMatch(commonGroupingSymbolMapping.keySet()::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); // Push down conjuncts from the inherited predicate that apply to common grouping symbols PlanNode rewrittenNode = context.defaultRewrite(node, 
new ExpressionSymbolInliner(commonGroupingSymbolMapping).rewrite(combineConjuncts(conjuncts.get(true)))); // All other conjuncts, if any, will be in the filter node. if (!conjuncts.get(false).isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; } @Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Expression> context) { checkState(!DependencyExtractor.extractUnique(context.get()).contains(node.getMarkerSymbol()), "predicate depends on marker symbol"); return context.defaultRewrite(node, context.get()); } @Override public PlanNode visitSort(SortNode node, RewriteContext<Expression> context) { return context.defaultRewrite(node, context.get()); } @Override public PlanNode visitUnion(UnionNode node, RewriteContext<Expression> context) { boolean modified = false; ImmutableList.Builder<PlanNode> builder = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { Expression sourcePredicate = new ExpressionSymbolInliner(node.sourceSymbolMap(i)).rewrite(context.get()); PlanNode source = node.getSources().get(i); PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); if (rewrittenSource != source) { modified = true; } builder.add(rewrittenSource); } if (modified) { return new UnionNode(node.getId(), builder.build(), node.getSymbolMapping(), node.getOutputSymbols()); } return node; } @Deprecated @Override public PlanNode visitFilter(FilterNode node, RewriteContext<Expression> context) { return context.rewrite(node.getSource(), combineConjuncts(node.getPredicate(), context.get())); } @Override public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context) { Expression inheritedPredicate = context.get(); // See if we can rewrite outer joins in terms of a plain inner join node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate); Expression leftEffectivePredicate = 
EffectivePredicateExtractor.extract(node.getLeft(), symbolAllocator.getTypes()); Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight(), symbolAllocator.getTypes()); Expression joinPredicate = extractJoinPredicate(node); Expression leftPredicate; Expression rightPredicate; Expression postJoinPredicate; Expression newJoinPredicate; switch (node.getType()) { case INNER: InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, node.getLeft().getOutputSymbols()); leftPredicate = innerJoinPushDownResult.getLeftPredicate(); rightPredicate = innerJoinPushDownResult.getRightPredicate(); postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = innerJoinPushDownResult.getJoinPredicate(); break; case LEFT: OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, node.getLeft().getOutputSymbols()); leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate(); rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate(); postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate(); break; case RIGHT: OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate, rightEffectivePredicate, leftEffectivePredicate, joinPredicate, node.getRight().getOutputSymbols()); leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate(); rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate(); postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate(); break; case FULL: leftPredicate = BooleanLiteral.TRUE_LITERAL; rightPredicate = BooleanLiteral.TRUE_LITERAL; postJoinPredicate = inheritedPredicate; 
newJoinPredicate = joinPredicate; break; default: throw new UnsupportedOperationException("Unsupported join type: " + node.getType()); } newJoinPredicate = simplifyExpression(newJoinPredicate); // TODO: find a better way to directly optimize FALSE LITERAL in join predicate if (newJoinPredicate.equals(BooleanLiteral.FALSE_LITERAL)) { newJoinPredicate = new ComparisonExpression(ComparisonExpressionType.EQUAL, new LongLiteral("0"), new LongLiteral("1")); } PlanNode leftSource = context.rewrite(node.getLeft(), leftPredicate); PlanNode rightSource = context.rewrite(node.getRight(), rightPredicate); PlanNode output = node; if (leftSource != node.getLeft() || rightSource != node.getRight() || !expressionEquivalence.areExpressionsEquivalent(session, newJoinPredicate, joinPredicate, types) || node.getCriteria().isEmpty()) { // Create identity projections for all existing symbols Assignments.Builder leftProjections = Assignments.builder(); leftProjections.putAll(node.getLeft() .getOutputSymbols().stream() .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); Assignments.Builder rightProjections = Assignments.builder(); rightProjections.putAll(node.getRight() .getOutputSymbols().stream() .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); // Create new projections for the new join clauses ImmutableList.Builder<JoinNode.EquiJoinClause> joinConditionBuilder = ImmutableList.builder(); ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder(); for (Expression conjunct : extractConjuncts(newJoinPredicate)) { if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) { ComparisonExpression equality = (ComparisonExpression) conjunct; boolean alignedComparison = Iterables.all(DependencyExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight(); Expression rightExpression = (alignedComparison) ? 
equality.getRight() : equality.getLeft(); Symbol leftSymbol = symbolForExpression(leftExpression); if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) { leftProjections.put(leftSymbol, leftExpression); } Symbol rightSymbol = symbolForExpression(rightExpression); if (!node.getRight().getOutputSymbols().contains(rightSymbol)) { rightProjections.put(rightSymbol, rightExpression); } joinConditionBuilder.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol)); } else { joinFilterBuilder.add(conjunct); } } Optional<Expression> newJoinFilter = Optional.of(combineConjuncts(joinFilterBuilder.build())); if (newJoinFilter.get() == BooleanLiteral.TRUE_LITERAL) { newJoinFilter = Optional.empty(); } leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); output = createJoinNodeWithExpectedOutputs( node.getOutputSymbols(), idAllocator, node.getType(), leftSource, rightSource, newJoinFilter, joinConditionBuilder.build(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); } if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate); } return output; } private Symbol symbolForExpression(Expression expression) { if (expression instanceof SymbolReference) { return Symbol.from(expression); } return symbolAllocator.newSymbol(expression, extractType(expression)); } private static PlanNode createJoinNodeWithExpectedOutputs( List<Symbol> expectedOutputs, PlanNodeIdAllocator idAllocator, JoinNode.Type type, PlanNode left, PlanNode right, Optional<Expression> filter, List<JoinNode.EquiJoinClause> conditions, Optional<Symbol> leftHashSymbol, Optional<Symbol> rightHashSymbol, Optional<JoinNode.DistributionType> distributionType) { // TODO: this should be removed once join nodes with output column pruning is supported for cross join if 
(conditions.isEmpty() && !filter.isPresent()) { PlanNode output = new JoinNode( idAllocator.getNextId(), type, left, right, conditions, ImmutableList.<Symbol>builder() .addAll(left.getOutputSymbols()) .addAll(right.getOutputSymbols()) .build(), filter, leftHashSymbol, rightHashSymbol, distributionType); if (!output.getOutputSymbols().equals(expectedOutputs)) { // Introduce a projection to constrain the outputs to what was originally expected // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) output = new ProjectNode( idAllocator.getNextId(), output, Assignments.identity(expectedOutputs)); } return output; } else { return new JoinNode(idAllocator.getNextId(), type, left, right, conditions, expectedOutputs, filter, leftHashSymbol, rightHashSymbol, distributionType); } } private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection<Symbol> outerSymbols) { checkArgument(Iterables.all(DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); checkArgument(Iterables.all(DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder(); // Strip out non-deterministic conjuncts postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(DeterminismEvaluator::isDeterministic))); inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate); outerEffectivePredicate = 
stripNonDeterministicConjuncts(outerEffectivePredicate); innerEffectivePredicate = stripNonDeterministicConjuncts(innerEffectivePredicate); joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(DeterminismEvaluator::isDeterministic))); joinPredicate = stripNonDeterministicConjuncts(joinPredicate); // Generate equality inferences EqualityInference inheritedInference = createEqualityInference(inheritedPredicate); EqualityInference outerInference = createEqualityInference(inheritedPredicate, outerEffectivePredicate); EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols)); Expression outerOnlyInheritedEqualities = combineConjuncts(equalityPartition.getScopeEqualities()); EqualityInference potentialNullSymbolInference = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); // See if we can push inherited predicates down for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerSymbols)); if (outerRewritten != null) { outerPushdownConjuncts.add(outerRewritten); // A conjunct can only be pushed down into an inner side if it can be rewritten in terms of the outer side Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerSymbols))); if (innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); } } else { postJoinConjuncts.add(conjunct); } } // Add the equalities from the inferences back in outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); // See if we can push down any outer effective predicates to the inner side for (Expression conjunct : 
EqualityInference.nonInferrableConjuncts(outerEffectivePredicate)) { Expression rewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols))); if (rewritten != null) { innerPushdownConjuncts.add(rewritten); } } // See if we can push down join predicates to the inner side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) { Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols))); if (innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); } else { joinConjuncts.add(conjunct); } } // Push outer and join equalities into the inner side. For example: // SELECT * FROM nation LEFT OUTER JOIN region ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah' EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate); innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerSymbols))).getScopeEqualities()); // TODO: we can further improve simplifying the equalities by considering other relationships from the outer side EqualityInference.EqualityPartition joinEqualityPartition = createEqualityInference(joinPredicate).generateEqualitiesPartitionedBy(not(in(outerSymbols))); innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities()); joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities()) .addAll(joinEqualityPartition.getScopeStraddlingEqualities()); return new OuterJoinPushDownResult(combineConjuncts(outerPushdownConjuncts.build()), combineConjuncts(innerPushdownConjuncts.build()), combineConjuncts(joinConjuncts.build()), combineConjuncts(postJoinConjuncts.build())); } private static class OuterJoinPushDownResult { private final Expression outerJoinPredicate; private final Expression innerJoinPredicate; private 
final Expression joinPredicate; private final Expression postJoinPredicate; private OuterJoinPushDownResult(Expression outerJoinPredicate, Expression innerJoinPredicate, Expression joinPredicate, Expression postJoinPredicate) { this.outerJoinPredicate = outerJoinPredicate; this.innerJoinPredicate = innerJoinPredicate; this.joinPredicate = joinPredicate; this.postJoinPredicate = postJoinPredicate; } private Expression getOuterJoinPredicate() { return outerJoinPredicate; } private Expression getInnerJoinPredicate() { return innerJoinPredicate; } public Expression getJoinPredicate() { return joinPredicate; } private Expression getPostJoinPredicate() { return postJoinPredicate; } } private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection<Symbol> leftSymbols) { checkArgument(Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); checkArgument(Iterables.all(DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder(); // Strip out non-deterministic conjuncts joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(DeterminismEvaluator::isDeterministic))); inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate); joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(DeterminismEvaluator::isDeterministic))); joinPredicate = stripNonDeterministicConjuncts(joinPredicate); leftEffectivePredicate = stripNonDeterministicConjuncts(leftEffectivePredicate); 
rightEffectivePredicate = stripNonDeterministicConjuncts(rightEffectivePredicate); // Generate equality inferences EqualityInference allInference = createEqualityInference(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate); EqualityInference allInferenceWithoutLeftInferred = createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate); EqualityInference allInferenceWithoutRightInferred = createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate); // Sort through conjuncts in inheritedPredicate that were not used for inference for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftSymbols)); if (leftRewrittenConjunct != null) { leftPushDownConjuncts.add(leftRewrittenConjunct); } Expression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); if (rightRewrittenConjunct != null) { rightPushDownConjuncts.add(rightRewrittenConjunct); } // Drop predicate after join only if unable to push down to either side if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { joinConjuncts.add(conjunct); } } // See if we can push the right effective predicate to the left side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) { Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols)); if (rewritten != null) { leftPushDownConjuncts.add(rewritten); } } // See if we can push the left effective predicate to the right side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) { Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); if (rewritten != null) { rightPushDownConjuncts.add(rewritten); } } // See if we can push any parts of the join predicates to either side for (Expression conjunct : 
EqualityInference.nonInferrableConjuncts(joinPredicate)) { Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols)); if (leftRewritten != null) { leftPushDownConjuncts.add(leftRewritten); } Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); if (rightRewritten != null) { rightPushDownConjuncts.add(rightRewritten); } if (leftRewritten == null && rightRewritten == null) { joinConjuncts.add(conjunct); } } // Add equalities from the inference back in leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeEqualities()); rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftSymbols))).getScopeEqualities()); joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate return new InnerJoinPushDownResult(combineConjuncts(leftPushDownConjuncts.build()), combineConjuncts(rightPushDownConjuncts.build()), combineConjuncts(joinConjuncts.build()), BooleanLiteral.TRUE_LITERAL); } private static class InnerJoinPushDownResult { private final Expression leftPredicate; private final Expression rightPredicate; private final Expression joinPredicate; private final Expression postJoinPredicate; private InnerJoinPushDownResult(Expression leftPredicate, Expression rightPredicate, Expression joinPredicate, Expression postJoinPredicate) { this.leftPredicate = leftPredicate; this.rightPredicate = rightPredicate; this.joinPredicate = joinPredicate; this.postJoinPredicate = postJoinPredicate; } private Expression getLeftPredicate() { return leftPredicate; } private Expression getRightPredicate() { return rightPredicate; } private Expression getJoinPredicate() { return joinPredicate; } private Expression getPostJoinPredicate() { return postJoinPredicate; } } private static Expression 
extractJoinPredicate(JoinNode joinNode)
        {
            ImmutableList.Builder<Expression> builder = ImmutableList.builder();
            for (JoinNode.EquiJoinClause equiJoinClause : joinNode.getCriteria()) {
                builder.add(equalsExpression(equiJoinClause.getLeft(), equiJoinClause.getRight()));
            }
            joinNode.getFilter().ifPresent(builder::add);
            return combineConjuncts(builder.build());
        }

        /**
         * Builds the comparison {@code symbol1 = symbol2} between references to the two symbols.
         */
        private static Expression equalsExpression(Symbol symbol1, Symbol symbol2)
        {
            return new ComparisonExpression(ComparisonExpressionType.EQUAL, symbol1.toSymbolReference(), symbol2.toSymbolReference());
        }

        /**
         * Returns the type that expression analysis assigns to {@code expression}, resolving symbols
         * against the symbol allocator's current type map.
         */
        private Type extractType(Expression expression)
        {
            return getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList() /* parameters have already been replaced */).get(expression);
        }

        /**
         * Tries to weaken an outer join using the predicate that will be evaluated above it.
         * <p>
         * If some conjunct of {@code inheritedPredicate} evaluates to NULL or FALSE when all output
         * symbols of one join side are NULL (see {@code canConvertOuterToInner}), then every row the
         * outer join NULL-pads on that side would be filtered out anyway, so the join can be rewritten:
         * <ul>
         * <li>INNER: returned unchanged.</li>
         * <li>FULL: INNER if both sides qualify; LEFT if only the left side's symbols qualify;
         * RIGHT if only the right side's symbols qualify; otherwise unchanged.</li>
         * <li>LEFT: INNER if the right (NULL-padded) side qualifies, otherwise unchanged.</li>
         * <li>RIGHT: INNER if the left (NULL-padded) side qualifies, otherwise unchanged.</li>
         * </ul>
         * Every other property of the join (id, children, criteria, output symbols, filter, hash
         * symbols, distribution type) is copied into the new node as-is.
         *
         * @param node the join to normalize; its type must be INNER, LEFT, RIGHT or FULL (enforced with checkArgument)
         * @param inheritedPredicate the predicate being pushed down onto this join
         * @return {@code node} itself when no conversion applies, otherwise a new JoinNode with the weaker join type
         */
        private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate)
        {
            checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType());

            if (node.getType() == JoinNode.Type.INNER) {
                return node;
            }

            if (node.getType() == JoinNode.Type.FULL) {
                // Rows whose left-side symbols are all NULL are the unmatched right rows; if the predicate
                // rejects them, only LEFT-join semantics remain (and symmetrically for the right side).
                boolean canConvertToLeftJoin = canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate);
                boolean canConvertToRightJoin = canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate);
                if (!canConvertToLeftJoin && !canConvertToRightJoin) {
                    return node;
                }
                if (canConvertToLeftJoin && canConvertToRightJoin) {
                    return new JoinNode(node.getId(), INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType());
                }
                else {
                    return new JoinNode(node.getId(), canConvertToLeftJoin ? LEFT : RIGHT, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType());
                }
            }

            // LEFT/RIGHT: test the side that gets NULL-padded (the right side of a LEFT join, the left side
            // of a RIGHT join). Note that && binds tighter than || here.
            if (node.getType() == JoinNode.Type.LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) ||
                    node.getType() == JoinNode.Type.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate)) {
                return node;
            }
            return new JoinNode(node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType());
        }

        /**
         * Returns true if at least one deterministic conjunct of {@code inheritedPredicate} evaluates
         * to NULL or FALSE when every symbol in {@code innerSymbolsForOuterJoin} is bound to NULL.
         * Such a conjunct rejects every row that an outer join NULL-pads on that side.
         *
         * @param innerSymbolsForOuterJoin output symbols of the join side that the outer join may NULL-pad
         * @param inheritedPredicate the predicate evaluated above the join
         */
        private boolean canConvertOuterToInner(List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate)
        {
            Set<Symbol> innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin);
            for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
                if (DeterminismEvaluator.isDeterministic(conjunct)) {
                    // Only deterministic conjuncts are tested: a non-deterministic conjunct cannot be relied on
                    // to give the same response, so it is ignored for this check
                    Object response = nullInputEvaluator(innerSymbols, conjunct);
                    if (response == null || response instanceof NullLiteral || Boolean.FALSE.equals(response)) {
                        // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the inner side symbols of an outer join
                        // then this conjunct removes all effects of the outer join, and effectively turns this into an equivalent of an inner join.
                        // So, let's just rewrite this join as an INNER join
                        return true;
                    }
                }
            }
            return false;
        }

        /**
         * Constant-folds {@code expression} with the expression optimizer, resolving symbols through
         * {@code NoOpSymbolResolver} (presumably leaving symbol references in place — confirm), and
         * converts the optimized value back into an Expression of the original expression's type.
         * <p>
         * Temporary implementation for joins because the SimplifyExpressions optimizers cannot run
         * properly on join clauses. It is also used by visitTableScan in this class.
         */
        private Expression simplifyExpression(Expression expression)
        {
            IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypes(
                    session,
                    metadata,
                    sqlParser,
                    symbolAllocator.getTypes(),
                    expression,
                    emptyList() /* parameters have already been replaced */);
            ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes);
            return LiteralInterpreter.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(expression));
        }

        /**
         * Evaluates an expression's response to binding the specified input symbols to NULL.
         * <p>
         * Symbols not in {@code nullSymbols} resolve to their own symbol references, so the result may
         * be a folded constant (including {@code null}) or a partially-optimized expression (presumably
         * an Expression instance such as a NullLiteral — confirm against ExpressionInterpreter). The
         * caller canConvertOuterToInner checks it for {@code null}, {@code NullLiteral} and
         * {@code Boolean.FALSE}.
         */
        private Object nullInputEvaluator(final Collection<Symbol> nullSymbols, Expression expression)
        {
            IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypes(
                    session,
                    metadata,
                    sqlParser,
                    symbolAllocator.getTypes(),
                    expression,
                    emptyList() /* parameters have already been replaced */);
            return ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes)
                    .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference());
        }

        /**
         * Returns a predicate accepting deterministic {@code a = b} comparisons that split cleanly across
         * a join: each operand must reference at least one symbol, and one operand must reference only
         * {@code leftSymbols} while the other references none of them (either orientation is accepted).
         */
        private static Predicate<Expression> joinEqualityExpression(final Collection<Symbol> leftSymbols)
        {
            return expression -> {
                // At this point in time, our join predicates need to be deterministic
                if (isDeterministic(expression) && expression instanceof ComparisonExpression) {
                    ComparisonExpression comparison = (ComparisonExpression) expression;
                    if (comparison.getType() == ComparisonExpressionType.EQUAL) {
                        Set<Symbol> symbols1 = DependencyExtractor.extractUnique(comparison.getLeft());
                        Set<Symbol> symbols2 = DependencyExtractor.extractUnique(comparison.getRight());
                        // An operand without symbol references (e.g. the literal in a = 5) is rejected
                        if (symbols1.isEmpty() || symbols2.isEmpty()) {
                            return false;
                        }
                        // Operand 1 entirely on the left side and operand 2 entirely off it, or vice versa
                        return (Iterables.all(symbols1, in(leftSymbols)) && Iterables.all(symbols2, not(in(leftSymbols)))) ||
                                (Iterables.all(symbols2, in(leftSymbols)) && Iterables.all(symbols1, not(in(leftSymbols))));
                    }
                }
                return false;
            };
        }

        /**
         * Pushes the inherited predicate through a semi join, splitting it three ways:
         * <ul>
         * <li>Deterministic conjuncts of the inherited predicate and of the source's effective predicate
         * that can be rewritten purely over the filtering source join symbol (via the contrived equality
         * {@code sourceJoinSymbol = filteringSourceJoinSymbol}) are pushed into the filtering source,
         * each OR'ed with "filteringSourceJoinSymbol IS NULL" so NULL values there are left untouched
         * (presumably because NULLs in the filtering set change the semi join's NULL/FALSE result).</li>
         * <li>Inherited conjuncts expressible over the source's output symbols are pushed into the source;
         * non-deterministic ones are allowed here.</li>
         * <li>Everything else, such as conjuncts involving the semi join output, stays in a FilterNode
         * above the semi join.</li>
         * </ul>
         * A new SemiJoinNode is built only if one of the two children was actually rewritten.
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();

            Expression sourceEffectivePredicate = EffectivePredicateExtractor.extract(node.getSource(), symbolAllocator.getTypes());

            List<Expression> sourceConjuncts = new ArrayList<>();
            List<Expression> filteringSourceConjuncts = new ArrayList<>();
            List<Expression> postJoinConjuncts = new ArrayList<>();

            // TODO: see if there are predicates that can be inferred from the semi join output

            // Push inherited and source predicates to filtering source via a contrived join predicate (but needs to avoid touching NULL values in the filtering source)
            Expression joinPredicate = equalsExpression(node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol());
            EqualityInference joinInference = createEqualityInference(inheritedPredicate, sourceEffectivePredicate, joinPredicate);
            for (Expression conjunct : Iterables.concat(EqualityInference.nonInferrableConjuncts(inheritedPredicate), EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate))) {
                Expression rewrittenConjunct = joinInference.rewriteExpression(conjunct, equalTo(node.getFilteringSourceJoinSymbol()));
                if (rewrittenConjunct != null && DeterminismEvaluator.isDeterministic(rewrittenConjunct)) {
                    // Alter conjunct to include an OR filteringSourceJoinSymbol IS NULL disjunct
                    Expression rewrittenConjunctOrNull = expressionOrNullSymbols(Predicate.isEqual(node.getFilteringSourceJoinSymbol())).apply(rewrittenConjunct);
                    filteringSourceConjuncts.add(rewrittenConjunctOrNull);
                }
            }
            // Equalities that involve only the filtering source join symbol get the same IS NULL widening
            EqualityInference.EqualityPartition joinInferenceEqualityPartition = joinInference.generateEqualitiesPartitionedBy(equalTo(node.getFilteringSourceJoinSymbol()));
            filteringSourceConjuncts.addAll(joinInferenceEqualityPartition.getScopeEqualities().stream()
                    .map(expressionOrNullSymbols(Predicate.isEqual(node.getFilteringSourceJoinSymbol())))
                    .collect(Collectors.toList()));

            // Push inheritedPredicates down to the source if they don't involve the semi join output
            EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
                Expression rewrittenConjunct = inheritedInference.rewriteExpression(conjunct, in(node.getSource().getOutputSymbols()));
                // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down
                if (rewrittenConjunct != null) {
                    sourceConjuncts.add(rewrittenConjunct);
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            }

            // Add the inherited equality predicates back in
            EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(node.getSource().getOutputSymbols()));
            sourceConjuncts.addAll(equalityPartition.getScopeEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(sourceConjuncts));
            PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), combineConjuncts(filteringSourceConjuncts));

            // Rebuild only when a child changed (reference comparison against the original children)
            PlanNode output = node;
            if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) {
                output = new SemiJoinNode(node.getId(), rewrittenSource, rewrittenFilteringSource, node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput(), node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol(), node.getDistributionType());
            }
            if (!postJoinConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postJoinConjuncts));
            }
            return output;
        }

        /**
         * Pushes the deterministic conjuncts of the inherited predicate that can be rewritten over the
         * aggregation's grouping keys into its source. Non-deterministic conjuncts, and conjuncts that
         * cannot be expressed over the grouping keys, are kept in a FilterNode above the aggregation.
         * Aggregations without grouping keys are not pushed through; they fall back to visitPlan.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Expression> context)
        {
            if (node.getGroupingKeys().isEmpty()) {
                // cannot push predicates down through aggregations without any grouping columns
                return visitPlan(node, context);
            }

            Expression inheritedPredicate = context.get();

            EqualityInference equalityInference = createEqualityInference(inheritedPredicate);

            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postAggregationConjuncts = new ArrayList<>();

            // Strip out non-deterministic conjuncts
            postAggregationConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(DeterminismEvaluator::isDeterministic))));
            inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

            // Sort non-equality predicates by those that can be pushed down and those that cannot
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
                Expression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getGroupingKeys()));
                if (rewrittenConjunct != null) {
                    pushdownConjuncts.add(rewrittenConjunct);
                }
                else {
                    postAggregationConjuncts.add(conjunct);
                }
            }

            // Add the equality predicates back in
            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getGroupingKeys()));
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts));

            PlanNode output = node;
            if (rewrittenSource != node.getSource()) {
                output = new AggregationNode(node.getId(), rewrittenSource, node.getAggregations(), node.getFunctions(), node.getMasks(), node.getGroupingSets(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol());
            }
            if (!postAggregationConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postAggregationConjuncts));
            }
            return output;
        }

        /**
         * Pushes the deterministic conjuncts of the inherited predicate that can be rewritten over
         * {@code node.getReplicateSymbols()} into the unnest's source. Everything else, including
         * non-deterministic conjuncts and conjuncts on the unnest/ordinality outputs, is kept in a
         * FilterNode above the unnest.
         */
        @Override
        public PlanNode visitUnnest(UnnestNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();

            EqualityInference equalityInference = createEqualityInference(inheritedPredicate);

            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postUnnestConjuncts = new ArrayList<>();

            // Strip out non-deterministic conjuncts
            postUnnestConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(DeterminismEvaluator::isDeterministic))));
            inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

            // Sort non-equality predicates by those that can be pushed down and those that cannot
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
                Expression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getReplicateSymbols()));
                if (rewrittenConjunct != null) {
                    pushdownConjuncts.add(rewrittenConjunct);
                }
                else {
                    postUnnestConjuncts.add(conjunct);
                }
            }

            // Add the equality predicates back in
            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateSymbols()));
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts));

            PlanNode output = node;
            if (rewrittenSource != node.getSource()) {
                output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getUnnestSymbols(), node.getOrdinalitySymbol());
            }
            if (!postUnnestConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postUnnestConjuncts));
            }
            return output;
        }

        /**
         * Pushes the inherited predicate through the sample node unchanged.
         */
        @Override
        public PlanNode visitSample(SampleNode node, RewriteContext<Expression> context)
        {
            return context.defaultRewrite(node, context.get());
        }

        /**
         * The predicate stops at the table scan: it is simplified and, unless it reduces to TRUE,
         * placed in a FilterNode directly above the scan.
         */
        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Expression> context)
        {
            Expression predicate = simplifyExpression(context.get());

            if (!BooleanLiteral.TRUE_LITERAL.equals(predicate)) {
                return new FilterNode(idAllocator.getNextId(), node, predicate);
            }

            return node;
        }

        /**
         * Pushes the inherited predicate through the unique-id assignment. A predicate that references
         * the generated id column fails the checkState below, since it could not be evaluated beneath
         * the node that produces that column.
         */
        @Override
        public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext<Expression> context)
        {
            Set<Symbol> predicateSymbols = DependencyExtractor.extractUnique(context.get());
            checkState(!predicateSymbols.contains(node.getIdColumn()), "UniqueId in predicate is not yet supported");
            return context.defaultRewrite(node, context.get());
        }
    }
}