/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.execution;

import com.facebook.presto.OutputBuffers;
import com.facebook.presto.Session;
import com.facebook.presto.execution.StateMachine.StateChangeListener;
import com.facebook.presto.execution.scheduler.SplitSchedulerStats;
import com.facebook.presto.failureDetector.FailureDetector;
import com.facebook.presto.metadata.RemoteTransactionHandle;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import io.airlift.units.Duration;

import javax.annotation.concurrent.ThreadSafe;

import java.net.URI;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static com.facebook.presto.failureDetector.FailureDetector.State.GONE;
import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_HOST_GONE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Sets.newConcurrentHashSet;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static java.util.Objects.requireNonNull;

/**
 * Drives the execution of a single query stage: creates remote tasks on worker nodes,
 * feeds them splits and exchange locations, and tracks task state to advance the
 * stage state machine.
 */
@ThreadSafe
public final class SqlStageExecution
{
    private final StageStateMachine stateMachine;
    private final RemoteTaskFactory remoteTaskFactory;
    private final NodeTaskMap nodeTaskMap;
    private final boolean summarizeTaskInfo;
    private final FailureDetector failureDetector;

    private final Map<PlanFragmentId, RemoteSourceNode> exchangeSources;
    private final Map<Node, Set<RemoteTask>> tasks = new ConcurrentHashMap<>();

    private final AtomicInteger nextTaskId = new AtomicInteger();
    private final Set<TaskId> allTasks = newConcurrentHashSet();
    private final Set<TaskId> finishedTasks = newConcurrentHashSet();

    private final AtomicBoolean splitsScheduled = new AtomicBoolean();

    private final Multimap<PlanNodeId, URI> exchangeLocations = HashMultimap.create();
    private final Set<PlanNodeId> completeSources = newConcurrentHashSet();
    private final Set<PlanFragmentId> completeSourceFragments = newConcurrentHashSet();

    private final AtomicReference<OutputBuffers> outputBuffers = new AtomicReference<>();

    public SqlStageExecution(
            StageId stageId,
            URI location,
            PlanFragment fragment,
            RemoteTaskFactory remoteTaskFactory,
            Session session,
            boolean summarizeTaskInfo,
            NodeTaskMap nodeTaskMap,
            ExecutorService executor,
            FailureDetector failureDetector,
            SplitSchedulerStats schedulerStats)
    {
        this(new StageStateMachine(
                        requireNonNull(stageId, "stageId is null"),
                        requireNonNull(location, "location is null"),
                        requireNonNull(session, "session is null"),
                        requireNonNull(fragment, "fragment is null"),
                        requireNonNull(executor, "executor is null"),
                        requireNonNull(schedulerStats, "schedulerStats is null")),
                remoteTaskFactory,
                nodeTaskMap,
                summarizeTaskInfo,
                failureDetector);
    }

    public SqlStageExecution(StageStateMachine stateMachine, RemoteTaskFactory remoteTaskFactory, NodeTaskMap nodeTaskMap, boolean summarizeTaskInfo, FailureDetector failureDetector)
    {
        this.stateMachine = stateMachine;
        this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        this.summarizeTaskInfo = summarizeTaskInfo;
        this.failureDetector = requireNonNull(failureDetector, "failureDetector is null");

        // index each upstream fragment by the remote source node that reads from it
        ImmutableMap.Builder<PlanFragmentId, RemoteSourceNode> fragmentToExchangeSource = ImmutableMap.builder();
        for (RemoteSourceNode remoteSourceNode : stateMachine.getFragment().getRemoteSourceNodes()) {
            for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) {
                fragmentToExchangeSource.put(planFragmentId, remoteSourceNode);
            }
        }
        this.exchangeSources = fragmentToExchangeSource.build();
    }

    public StageId getStageId()
    {
        return stateMachine.getStageId();
    }

    public StageState getState()
    {
        return stateMachine.getState();
    }

    public void addStateChangeListener(StateChangeListener<StageState> stateChangeListener)
    {
        stateMachine.addStateChangeListener(stateChangeListener::stateChanged);
    }

    public PlanFragment getFragment()
    {
        return stateMachine.getFragment();
    }

    public void beginScheduling()
    {
        stateMachine.transitionToScheduling();
    }

    public synchronized void transitionToSchedulingSplits()
    {
        stateMachine.transitionToSchedulingSplits();
    }

    public synchronized void schedulingComplete()
    {
        if (!stateMachine.transitionToScheduled()) {
            return;
        }

        // move the stage to RUNNING if any task has already started running
        if (getAllTasks().stream().anyMatch(task -> task.getTaskStatus().getState() == TaskState.RUNNING)) {
            stateMachine.transitionToRunning();
        }
        if (finishedTasks.containsAll(allTasks)) {
            stateMachine.transitionToFinished();
        }

        for (PlanNodeId partitionedSource : stateMachine.getFragment().getPartitionedSources()) {
            for (RemoteTask task : getAllTasks()) {
                task.noMoreSplits(partitionedSource);
            }
            completeSources.add(partitionedSource);
        }
    }

    public synchronized void cancel()
    {
        stateMachine.transitionToCanceled();
        getAllTasks().forEach(RemoteTask::cancel);
    }

    public synchronized void abort()
    {
        stateMachine.transitionToAborted();
        getAllTasks().forEach(RemoteTask::abort);
    }

    public long getMemoryReservation()
    {
        return stateMachine.getMemoryReservation();
    }

    public synchronized Duration getTotalCpuTime()
    {
        long millis = getAllTasks().stream()
                .mapToLong(task -> task.getTaskInfo().getStats().getTotalCpuTime().toMillis())
                .sum();
        return new Duration(millis, TimeUnit.MILLISECONDS);
    }

    public StageInfo getStageInfo()
    {
        return stateMachine.getStageInfo(
                () -> getAllTasks().stream()
                        .map(RemoteTask::getTaskInfo)
                        .collect(toImmutableList()),
                ImmutableList::of);
    }
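    /**
     * Registers the output locations of an upstream fragment. Each location is converted into a
     * remote split and pushed to every existing task; the locations are also recorded in
     * {@code exchangeLocations} so tasks scheduled later receive them as initial splits.
     */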
    public synchronized void addExchangeLocations(PlanFragmentId fragmentId, Set<URI> exchangeLocations, boolean noMoreExchangeLocations)
    {
        requireNonNull(fragmentId, "fragmentId is null");
        requireNonNull(exchangeLocations, "exchangeLocations is null");

        RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
        checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet());

        this.exchangeLocations.putAll(remoteSource.getId(), exchangeLocations);

        for (RemoteTask task : getAllTasks()) {
            ImmutableMultimap.Builder<PlanNodeId, Split> newSplits = ImmutableMultimap.builder();
            for (URI exchangeLocation : exchangeLocations) {
                newSplits.put(remoteSource.getId(), createRemoteSplitFor(task.getTaskId(), exchangeLocation));
            }
            task.addSplits(newSplits.build());
        }

        if (noMoreExchangeLocations) {
            completeSourceFragments.add(fragmentId);

            // is the source now complete?
            if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) {
                completeSources.add(remoteSource.getId());
                for (RemoteTask task : getAllTasks()) {
                    task.noMoreSplits(remoteSource.getId());
                }
            }
        }
    }

    public synchronized void setOutputBuffers(OutputBuffers outputBuffers)
    {
        requireNonNull(outputBuffers, "outputBuffers is null");

        while (true) {
            OutputBuffers currentOutputBuffers = this.outputBuffers.get();
            if (currentOutputBuffers != null) {
                if (outputBuffers.getVersion() <= currentOutputBuffers.getVersion()) {
                    return;
                }
                currentOutputBuffers.checkValidTransition(outputBuffers);
            }

            if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) {
                for (RemoteTask task : getAllTasks()) {
                    task.setOutputBuffers(outputBuffers);
                }
                return;
            }
        }
    }

    // do not synchronize
    // this is used for query info building which should be independent of scheduling work
    public boolean hasTasks()
    {
        return !tasks.isEmpty();
    }

    // do not synchronize
    // this is used for query info building which should be independent of scheduling work
    public List<RemoteTask> getAllTasks()
    {
        return tasks.values().stream()
                .flatMap(Set::stream)
                .collect(toImmutableList());
    }

    public synchronized RemoteTask scheduleTask(Node node, int partition)
    {
        requireNonNull(node, "node is null");
        checkState(!splitsScheduled.get(), "scheduleTask can not be called once splits have been scheduled");
        return scheduleTask(node, new TaskId(stateMachine.getStageId(), partition), ImmutableMultimap.of());
    }
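    /**
     * Adds splits to the task running on the given node, creating the task on first use.
     * This path creates at most one task per node; scheduling by explicit partition number
     * goes through {@link #scheduleTask(Node, int)} instead.
     */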
    public synchronized Set<RemoteTask> scheduleSplits(Node node, Multimap<PlanNodeId, Split> splits)
    {
        requireNonNull(node, "node is null");
        requireNonNull(splits, "splits is null");

        splitsScheduled.set(true);

        checkArgument(stateMachine.getFragment().getPartitionedSources().containsAll(splits.keySet()), "Invalid splits");

        ImmutableSet.Builder<RemoteTask> newTasks = ImmutableSet.builder();
        Collection<RemoteTask> tasks = this.tasks.get(node);
        if (tasks == null) {
            // The output buffer depends on the task id starting from 0 and being sequential, since each
            // task is assigned a private buffer based on task id.
            TaskId taskId = new TaskId(stateMachine.getStageId(), nextTaskId.getAndIncrement());
            newTasks.add(scheduleTask(node, taskId, splits));
        }
        else {
            RemoteTask task = tasks.iterator().next();
            task.addSplits(splits);
        }
        return newTasks.build();
    }

    private synchronized RemoteTask scheduleTask(Node node, TaskId taskId, Multimap<PlanNodeId, Split> sourceSplits)
    {
        checkArgument(!allTasks.contains(taskId), "A task with id %s already exists", taskId);

        ImmutableMultimap.Builder<PlanNodeId, Split> initialSplits = ImmutableMultimap.builder();
        initialSplits.putAll(sourceSplits);
        for (Entry<PlanNodeId, URI> entry : exchangeLocations.entries()) {
            initialSplits.put(entry.getKey(), createRemoteSplitFor(taskId, entry.getValue()));
        }

        OutputBuffers outputBuffers = this.outputBuffers.get();
        checkState(outputBuffers != null, "Initial output buffers must be set before a task can be scheduled");

        RemoteTask task = remoteTaskFactory.createRemoteTask(
                stateMachine.getSession(),
                taskId,
                node,
                stateMachine.getFragment(),
                initialSplits.build(),
                outputBuffers,
                nodeTaskMap.createPartitionedSplitCountTracker(node, taskId),
                summarizeTaskInfo);

        completeSources.forEach(task::noMoreSplits);

        allTasks.add(taskId);
        tasks.computeIfAbsent(node, key -> newConcurrentHashSet()).add(task);
        nodeTaskMap.addTask(node, task);

        task.addStateChangeListener(new StageTaskListener());

        if (!stateMachine.getState().isDone()) {
            task.start();
        }
        else {
            // stage finished while we were scheduling this task
            task.abort();
        }

        return task;
    }

    public Set<Node> getScheduledNodes()
    {
        return ImmutableSet.copyOf(tasks.keySet());
    }

    public void recordGetSplitTime(long start)
    {
        stateMachine.recordGetSplitTime(start);
    }

    private static Split createRemoteSplitFor(TaskId taskId, URI taskLocation)
    {
        // Fetch the results from the buffer assigned to the task based on id,
        // i.e. {taskLocation}/results/{taskId.getId()}
        URI splitLocation = uriBuilderFrom(taskLocation).appendPath("results").appendPath(String.valueOf(taskId.getId())).build();
        return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(splitLocation));
    }

    @Override
    public String toString()
    {
        return stateMachine.toString();
    }
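    /**
     * Listens for status updates from a single remote task: feeds memory usage deltas into the
     * stage state machine, promotes task failures (rewriting transport failures from hosts the
     * failure detector considers GONE) to stage failures, and records finished tasks so the
     * stage can transition to FINISHED once all tasks complete.
     */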
    private class StageTaskListener
            implements StateChangeListener<TaskStatus>
    {
        private long previousMemory;

        @Override
        public void stateChanged(TaskStatus taskStatus)
        {
            updateMemoryUsage(taskStatus);

            StageState stageState = getState();
            if (stageState.isDone()) {
                return;
            }

            TaskState taskState = taskStatus.getState();
            if (taskState == TaskState.FAILED) {
                RuntimeException failure = taskStatus.getFailures().stream()
                        .findFirst()
                        .map(this::rewriteTransportFailure)
                        .map(ExecutionFailureInfo::toException)
                        .orElse(new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
                stateMachine.transitionToFailed(failure);
            }
            else if (taskState == TaskState.ABORTED) {
                // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
                stateMachine.transitionToFailed(new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
            }
            else if (taskState == TaskState.FINISHED) {
                finishedTasks.add(taskStatus.getTaskId());
            }

            if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING) {
                if (taskState == TaskState.RUNNING) {
                    stateMachine.transitionToRunning();
                }
                if (finishedTasks.containsAll(allTasks)) {
                    stateMachine.transitionToFinished();
                }
            }
        }

        private synchronized void updateMemoryUsage(TaskStatus taskStatus)
        {
            long currentMemory = taskStatus.getMemoryReservation().toBytes();
            long deltaMemoryInBytes = currentMemory - previousMemory;
            previousMemory = currentMemory;
            stateMachine.updateMemoryUsage(deltaMemoryInBytes);
        }

        private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo)
        {
            if (executionFailureInfo.getRemoteHost() != null && failureDetector.getState(executionFailureInfo.getRemoteHost()) == GONE) {
                return new ExecutionFailureInfo(
                        executionFailureInfo.getType(),
                        executionFailureInfo.getMessage(),
                        executionFailureInfo.getCause(),
                        executionFailureInfo.getSuppressed(),
                        executionFailureInfo.getStack(),
                        executionFailureInfo.getErrorLocation(),
                        REMOTE_HOST_GONE.toErrorCode(),
                        executionFailureInfo.getRemoteHost());
            }
            else {
                return executionFailureInfo;
            }
        }
    }
}