/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.execution.scheduler; import com.facebook.presto.execution.SqlStageExecution; import com.facebook.presto.execution.StageState; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.jgrapht.DirectedGraph; import org.jgrapht.alg.StrongConnectivityInspector; import org.jgrapht.graph.DefaultDirectedGraph; import org.jgrapht.graph.DefaultEdge; import org.jgrapht.traverse.TopologicalOrderIterator; import javax.annotation.concurrent.NotThreadSafe; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import static com.facebook.presto.execution.StageState.RUNNING; import static com.facebook.presto.execution.StageState.SCHEDULED; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.function.Function.identity; @NotThreadSafe public class PhasedExecutionSchedule implements ExecutionSchedule { private final List<Set<SqlStageExecution>> schedulePhases; private final Set<SqlStageExecution> activeSources = new HashSet<>(); public PhasedExecutionSchedule(Collection<SqlStageExecution> stages) { List<Set<PlanFragmentId>> phases = extractPhases(stages.stream().map(SqlStageExecution::getFragment).collect(toImmutableList())); Map<PlanFragmentId, SqlStageExecution> stagesByFragmentId = stages.stream().collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); // create a mutable list of mutable sets of stages, so we can remove completed stages schedulePhases = new ArrayList<>(); for (Set<PlanFragmentId> phase : phases) { schedulePhases.add(phase.stream() .map(stagesByFragmentId::get) .collect(Collectors.toCollection(HashSet::new))); } } @Override public Set<SqlStageExecution> getStagesToSchedule() { removeCompletedStages(); addPhasesIfNecessary(); if (isFinished()) { return ImmutableSet.of(); } return activeSources; } private void removeCompletedStages() { for (Iterator<SqlStageExecution> stageIterator = activeSources.iterator(); stageIterator.hasNext(); ) { StageState state = stageIterator.next().getState(); if (state == SCHEDULED || state == RUNNING || state.isDone()) { stageIterator.remove(); } } } private void addPhasesIfNecessary() { // we want at least one source distributed phase in the active sources if (hasSourceDistributedStage(activeSources)) { return; } while (!schedulePhases.isEmpty()) { Set<SqlStageExecution> phase = schedulePhases.remove(0); activeSources.addAll(phase); if (hasSourceDistributedStage(phase)) { return; } } } private static boolean hasSourceDistributedStage(Set<SqlStageExecution> phase) { return phase.stream().anyMatch(stage -> !stage.getFragment().getPartitionedSources().isEmpty()); } @Override public boolean isFinished() { return activeSources.isEmpty() && schedulePhases.isEmpty(); } @VisibleForTesting static List<Set<PlanFragmentId>> extractPhases(Collection<PlanFragment> fragments) { // Build a graph where the plan fragments are vertexes and the edges represent // a before -> after relationship. For example, a join hash build has an edge // to the join probe. DirectedGraph<PlanFragmentId, DefaultEdge> graph = new DefaultDirectedGraph<>(DefaultEdge.class); fragments.forEach(fragment -> graph.addVertex(fragment.getId())); Visitor visitor = new Visitor(fragments, graph); for (PlanFragment fragment : fragments) { visitor.processFragment(fragment.getId()); } // Computes all the strongly connected components of the directed graph. // These are the "phases" which hold the set of fragments that must be started // at the same time to avoid deadlock. List<Set<PlanFragmentId>> components = new StrongConnectivityInspector<>(graph).stronglyConnectedSets(); Map<PlanFragmentId, Set<PlanFragmentId>> componentMembership = new HashMap<>(); for (Set<PlanFragmentId> component : components) { for (PlanFragmentId planFragmentId : component) { componentMembership.put(planFragmentId, component); } } // build graph of components (phases) DirectedGraph<Set<PlanFragmentId>, DefaultEdge> componentGraph = new DefaultDirectedGraph<>(DefaultEdge.class); components.forEach(componentGraph::addVertex); for (DefaultEdge edge : graph.edgeSet()) { PlanFragmentId source = graph.getEdgeSource(edge); PlanFragmentId target = graph.getEdgeTarget(edge); Set<PlanFragmentId> from = componentMembership.get(source); Set<PlanFragmentId> to = componentMembership.get(target); if (!from.equals(to)) { // the topological order iterator below doesn't include vertices that have self-edges, so don't add them componentGraph.addEdge(from, to); } } List<Set<PlanFragmentId>> schedulePhases = ImmutableList.copyOf(new TopologicalOrderIterator<>(componentGraph)); return schedulePhases; } private static class Visitor extends PlanVisitor<PlanFragmentId, Set<PlanFragmentId>> { private final Map<PlanFragmentId, PlanFragment> fragments; private final DirectedGraph<PlanFragmentId, DefaultEdge> graph; private final Map<PlanFragmentId, Set<PlanFragmentId>> fragmentSources = new HashMap<>(); public Visitor(Collection<PlanFragment> fragments, DirectedGraph<PlanFragmentId, DefaultEdge> graph) { this.fragments = fragments.stream() .collect(toImmutableMap(PlanFragment::getId, identity())); this.graph = graph; } public Set<PlanFragmentId> processFragment(PlanFragmentId planFragmentId) { return fragmentSources.computeIfAbsent(planFragmentId, fragmentId -> processFragment(fragments.get(fragmentId))); } private Set<PlanFragmentId> processFragment(PlanFragment fragment) { Set<PlanFragmentId> sources = fragment.getRoot().accept(this, fragment.getId()); return ImmutableSet.<PlanFragmentId>builder().add(fragment.getId()).addAll(sources).build(); } @Override public Set<PlanFragmentId> visitJoin(JoinNode node, PlanFragmentId currentFragmentId) { return processJoin(node.getRight(), node.getLeft(), currentFragmentId); } @Override public Set<PlanFragmentId> visitSemiJoin(SemiJoinNode node, PlanFragmentId currentFragmentId) { return processJoin(node.getFilteringSource(), node.getSource(), currentFragmentId); } @Override public Set<PlanFragmentId> visitIndexJoin(IndexJoinNode node, PlanFragmentId currentFragmentId) { return processJoin(node.getIndexSource(), node.getProbeSource(), currentFragmentId); } private Set<PlanFragmentId> processJoin(PlanNode build, PlanNode probe, PlanFragmentId currentFragmentId) { Set<PlanFragmentId> buildSources = build.accept(this, currentFragmentId); Set<PlanFragmentId> probeSources = probe.accept(this, currentFragmentId); for (PlanFragmentId buildSource : buildSources) { for (PlanFragmentId probeSource : probeSources) { graph.addEdge(buildSource, probeSource); } } return ImmutableSet.<PlanFragmentId>builder() .addAll(buildSources) .addAll(probeSources) .build(); } @Override public Set<PlanFragmentId> visitRemoteSource(RemoteSourceNode node, PlanFragmentId currentFragmentId) { ImmutableSet.Builder<PlanFragmentId> sources = ImmutableSet.builder(); Set<PlanFragmentId> previousFragmentSources = ImmutableSet.of(); for (PlanFragmentId remoteFragment : node.getSourceFragmentIds()) { // this current fragment depends on the remote fragment graph.addEdge(currentFragmentId, remoteFragment); // get all sources for the remote fragment Set<PlanFragmentId> remoteFragmentSources = processFragment(remoteFragment); sources.addAll(remoteFragmentSources); // For UNION there can be multiple sources. // Link the previous source to the current source, so we only // schedule one at a time. addEdges(previousFragmentSources, remoteFragmentSources); previousFragmentSources = remoteFragmentSources; } return sources.build(); } @Override public Set<PlanFragmentId> visitExchange(ExchangeNode node, PlanFragmentId currentFragmentId) { checkArgument(node.getScope() == LOCAL, "Only local exchanges are supported in the phased execution scheduler"); ImmutableSet.Builder<PlanFragmentId> allSources = ImmutableSet.builder(); // Link the source fragments together, so we only schedule one at a time. Set<PlanFragmentId> previousSources = ImmutableSet.of(); for (PlanNode subPlanNode : node.getSources()) { Set<PlanFragmentId> currentSources = subPlanNode.accept(this, currentFragmentId); allSources.addAll(currentSources); addEdges(previousSources, currentSources); previousSources = currentSources; } return allSources.build(); } @Override public Set<PlanFragmentId> visitUnion(UnionNode node, PlanFragmentId currentFragmentId) { ImmutableSet.Builder<PlanFragmentId> allSources = ImmutableSet.builder(); // Link the source fragments together, so we only schedule one at a time. Set<PlanFragmentId> previousSources = ImmutableSet.of(); for (PlanNode subPlanNode : node.getSources()) { Set<PlanFragmentId> currentSources = subPlanNode.accept(this, currentFragmentId); allSources.addAll(currentSources); addEdges(previousSources, currentSources); previousSources = currentSources; } return allSources.build(); } @Override protected Set<PlanFragmentId> visitPlan(PlanNode node, PlanFragmentId currentFragmentId) { List<PlanNode> sources = node.getSources(); if (sources.isEmpty()) { return ImmutableSet.of(currentFragmentId); } if (sources.size() == 1) { return sources.get(0).accept(this, currentFragmentId); } throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } private void addEdges(Set<PlanFragmentId> sourceFragments, Set<PlanFragmentId> targetFragments) { for (PlanFragmentId targetFragment : targetFragments) { for (PlanFragmentId sourceFragment : sourceFragments) { graph.addEdge(sourceFragment, targetFragment); } } } } }