/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; import java.util.Set; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class EliminateCrossJoins implements PlanOptimizer { @Override public PlanNode optimize( PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { return plan; } List<JoinGraph> joinGraphs = JoinGraph.buildFrom(plan); for (int i = joinGraphs.size() - 1; i >= 0; i--) { JoinGraph graph = joinGraphs.get(i); List<Integer> joinOrder = getJoinOrder(graph); if (isOriginalOrder(joinOrder)) { continue; } plan = rewriteWith(new Rewriter(idAllocator, graph, joinOrder), plan); } return plan; } public static boolean isOriginalOrder(List<Integer> joinOrder) { for (int i = 0; i < joinOrder.size(); i++) { if (joinOrder.get(i) != i) { return false; } } return true; } /** * Given JoinGraph determine the order of joins between graph nodes * by traversing JoinGraph. Any graph traversal algorithm could be used * here (like BFS or DFS), but we use PriorityQueue to preserve * original JoinOrder as mush as it is possible. PriorityQueue returns * next nodes to join in order of their occurrence in original Plan. */ public static List<Integer> getJoinOrder(JoinGraph graph) { ImmutableList.Builder<PlanNode> joinOrder = ImmutableList.builder(); Map<PlanNodeId, Integer> priorities = new HashMap<>(); for (int i = 0; i < graph.size(); i++) { priorities.put(graph.getNode(i).getId(), i); } PriorityQueue<PlanNode> nodesToVisit = new PriorityQueue<>( graph.size(), (Comparator<PlanNode>) (node1, node2) -> priorities.get(node1.getId()).compareTo(priorities.get(node2.getId()))); Set<PlanNode> visited = new HashSet<>(); nodesToVisit.add(graph.getNode(0)); while (!nodesToVisit.isEmpty()) { PlanNode node = nodesToVisit.poll(); if (!visited.contains(node)) { visited.add(node); joinOrder.add(node); for (JoinGraph.Edge edge : graph.getEdges(node)) { nodesToVisit.add(edge.getTargetNode()); } } if (nodesToVisit.isEmpty() && visited.size() < graph.size()) { // disconnected graph, find new starting point Optional<PlanNode> firstNotVisitedNode = graph.getNodes().stream() .filter(graphNode -> !visited.contains(graphNode)) .findFirst(); if (firstNotVisitedNode.isPresent()) { nodesToVisit.add(firstNotVisitedNode.get()); } } } checkState(visited.size() == graph.size()); return joinOrder.build().stream() .map(node -> priorities.get(node.getId())) .collect(toImmutableList()); } private class Rewriter extends SimplePlanRewriter<PlanNode> { private final PlanNodeIdAllocator idAllocator; private final JoinGraph graph; private final List<Integer> joinOrder; public Rewriter(PlanNodeIdAllocator idAllocator, JoinGraph graph, List<Integer> joinOrder) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.graph = requireNonNull(graph, "graph is null"); this.joinOrder = requireNonNull(joinOrder, "joinOrder is null"); checkState(joinOrder.size() >= 2); } @Override public PlanNode visitPlan(PlanNode node, RewriteContext<PlanNode> context) { if (node.getId() != graph.getRootId()) { return context.defaultRewrite(node, context.get()); } PlanNode result = graph.getNode(joinOrder.get(0)); Set<PlanNodeId> alreadyJoinedNodes = new HashSet<>(); alreadyJoinedNodes.add(result.getId()); for (int i = 1; i < joinOrder.size(); i++) { PlanNode rightNode = graph.getNode(joinOrder.get(i)); alreadyJoinedNodes.add(rightNode.getId()); ImmutableList.Builder<JoinNode.EquiJoinClause> criteria = ImmutableList.builder(); for (JoinGraph.Edge edge : graph.getEdges(rightNode)) { PlanNode targetNode = edge.getTargetNode(); if (alreadyJoinedNodes.contains(targetNode.getId())) { criteria.add(new JoinNode.EquiJoinClause( edge.getTargetSymbol(), edge.getSourceSymbol())); } } result = new JoinNode( idAllocator.getNextId(), JoinNode.Type.INNER, result, rightNode, criteria.build(), ImmutableList.<Symbol>builder() .addAll(result.getOutputSymbols()) .addAll(rightNode.getOutputSymbols()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } List<Expression> filters = graph.getFilters(); for (Expression filter : filters) { result = new FilterNode( idAllocator.getNextId(), result, filter); } if (graph.getAssignments().isPresent()) { result = new ProjectNode( idAllocator.getNextId(), result, Assignments.copyOf(graph.getAssignments().get())); } if (!result.getOutputSymbols().equals(node.getOutputSymbols())) { // Introduce a projection to constrain the outputs to what was originally expected // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) result = new ProjectNode( idAllocator.getNextId(), result, Assignments.identity(node.getOutputSymbols())); } return result; } } }