/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.execution; import com.facebook.presto.OutputBuffers; import com.facebook.presto.OutputBuffers.OutputBufferId; import com.facebook.presto.Session; import com.facebook.presto.TaskSource; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.execution.buffer.BufferResult; import com.facebook.presto.execution.buffer.LazyOutputBuffer; import com.facebook.presto.execution.buffer.OutputBuffer; import com.facebook.presto.memory.QueryContext; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.operator.TaskStats; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.base.Function; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.concurrent.SetThreadName; import io.airlift.log.Logger; import io.airlift.units.DataSize; import org.joda.time.DateTime; import javax.annotation.Nullable; import java.net.URI; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import 
static com.facebook.presto.util.Failures.toFailures;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static java.util.Objects.requireNonNull;

/**
 * One task of a query, as tracked on this node.
 * <p>
 * Owns the task's {@link TaskStateMachine} and {@link OutputBuffer}. Once a plan
 * fragment arrives via {@link #updateTask}, it also owns the {@link SqlTaskExecution}
 * created from it. When the state machine reaches a done state, a final
 * {@link TaskInfo} snapshot is stored, the output buffer is failed or destroyed, and
 * the {@code onDone} callback is invoked.
 */
public class SqlTask
{
    private static final Logger log = Logger.get(SqlTask.class);

    private final TaskId taskId;
    private final String taskInstanceId;
    private final URI location;
    private final TaskStateMachine taskStateMachine;
    private final OutputBuffer outputBuffer;
    private final QueryContext queryContext;

    private final SqlTaskExecutionFactory sqlTaskExecutionFactory;

    private final AtomicReference<DateTime> lastHeartbeat = new AtomicReference<>(DateTime.now());
    private final AtomicLong nextTaskInfoVersion = new AtomicLong(TaskStatus.STARTING_VERSION);

    // Empty holder until an execution is created, then the execution, then the final TaskInfo.
    // Replaced via compareAndSet so the final snapshot is stored at most once.
    private final AtomicReference<TaskHolder> taskHolderReference = new AtomicReference<>(new TaskHolder());
    private final AtomicBoolean needsPlan = new AtomicBoolean(true);

    /**
     * Creates the task and wires up end-of-life handling.
     *
     * @param taskId id of this task
     * @param location URI reported in this task's status
     * @param queryContext context used to reserve/free system memory for the output buffer
     * @param sqlTaskExecutionFactory factory used to create the execution once a fragment arrives
     * @param taskNotificationExecutor executor handed to the output buffer and state machine
     * @param onDone callback applied to this task after it reaches a done state; exceptions it throws are logged
     * @param maxBufferSize maximum size passed to the output buffer
     */
    public SqlTask(
            TaskId taskId,
            URI location,
            QueryContext queryContext,
            SqlTaskExecutionFactory sqlTaskExecutionFactory,
            ExecutorService taskNotificationExecutor,
            final Function<SqlTask, ?> onDone,
            DataSize maxBufferSize)
    {
        this.taskId = requireNonNull(taskId, "taskId is null");
        this.taskInstanceId = UUID.randomUUID().toString();
        this.location = requireNonNull(location, "location is null");
        this.queryContext = requireNonNull(queryContext, "queryContext is null");
        this.sqlTaskExecutionFactory = requireNonNull(sqlTaskExecutionFactory, "sqlTaskExecutionFactory is null");
        requireNonNull(taskNotificationExecutor, "taskNotificationExecutor is null");
        requireNonNull(onDone, "onDone is null");
        requireNonNull(maxBufferSize, "maxBufferSize is null");

        outputBuffer = new LazyOutputBuffer(taskId, taskInstanceId, taskNotificationExecutor, maxBufferSize, new UpdateSystemMemory(queryContext));
        taskStateMachine = new TaskStateMachine(taskId, taskNotificationExecutor);
        taskStateMachine.addStateChangeListener(new StateChangeListener<TaskState>()
        {
            @Override
            public void stateChanged(TaskState newState)
            {
                if (!newState.isDone()) {
                    return;
                }

                // store final task info; retry if another thread swapped the holder in between
                while (true) {
                    TaskHolder taskHolder = taskHolderReference.get();
                    if (taskHolder.isFinished()) {
                        // another concurrent worker already set the final state
                        return;
                    }

                    if (taskHolderReference.compareAndSet(taskHolder, new TaskHolder(createTaskInfo(taskHolder), taskHolder.getIoStats()))) {
                        break;
                    }
                }

                // make sure buffers are cleaned up
                if (newState == TaskState.FAILED || newState == TaskState.ABORTED) {
                    // don't close buffers for a failed query
                    // closed buffers signal to upstream tasks that everything finished cleanly
                    outputBuffer.fail();
                }
                else {
                    outputBuffer.destroy();
                }

                try {
                    onDone.apply(SqlTask.this);
                }
                catch (Exception e) {
                    log.warn(e, "Error running task cleanup callback %s", SqlTask.this.taskId);
                }
            }
        });
    }

    /**
     * Forwards output-buffer memory deltas to the query context: positive deltas are
     * reserved, non-positive deltas are freed.
     */
    private static final class UpdateSystemMemory
            implements SystemMemoryUsageListener
    {
        private final QueryContext queryContext;

        public UpdateSystemMemory(QueryContext queryContext)
        {
            this.queryContext = requireNonNull(queryContext, "queryContext is null");
        }

        @Override
        public void updateSystemMemoryUsage(long deltaMemoryInBytes)
        {
            if (deltaMemoryInBytes > 0) {
                queryContext.reserveSystemMemory(deltaMemoryInBytes);
            }
            else {
                queryContext.freeSystemMemory(-deltaMemoryInBytes);
            }
        }
    }

    /** Returns I/O stats: final if finished, live if running, empty if not yet started. */
    public SqlTaskIoStats getIoStats()
    {
        return taskHolderReference.get().getIoStats();
    }

    public TaskId getTaskId()
    {
        return taskStateMachine.getTaskId();
    }

    /** Returns the random UUID generated when this task object was constructed. */
    public String getTaskInstanceId()
    {
        return taskInstanceId;
    }

    /** Records the current time as the last heartbeat, reported in subsequent TaskInfos. */
    public void recordHeartbeat()
    {
        lastHeartbeat.set(DateTime.now());
    }

    /** Builds a fresh TaskInfo; each call consumes a new version number. */
    public TaskInfo getTaskInfo()
    {
        try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) {
            return createTaskInfo(taskHolderReference.get());
        }
    }

    /** Builds a fresh TaskStatus; each call consumes a new version number. */
    public TaskStatus getTaskStatus()
    {
        try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) {
            return createTaskStatus(taskHolderReference.get());
        }
    }

    private TaskStatus createTaskStatus(TaskHolder taskHolder)
    {
        return createTaskStatus(getTaskStats(taskHolder));
    }

    /** Builds a TaskStatus from an already-captured stats snapshot. */
    private TaskStatus createTaskStatus(TaskStats taskStats)
    {
        // Always return a new TaskStatus with a larger version number;
        // otherwise a client will not accept the update
        long versionNumber = nextTaskInfoVersion.getAndIncrement();

        TaskState state = taskStateMachine.getState();
        List<ExecutionFailureInfo> failures = ImmutableList.of();
        if (state == TaskState.FAILED) {
            failures = toFailures(taskStateMachine.getFailureCauses());
        }

        return new TaskStatus(taskStateMachine.getTaskId(),
                taskInstanceId,
                versionNumber,
                state,
                location,
                failures,
                taskStats.getQueuedPartitionedDrivers(),
                taskStats.getRunningPartitionedDrivers(),
                taskStats.getMemoryReservation());
    }

    /** Returns final, live, or placeholder stats depending on how far the task got. */
    private TaskStats getTaskStats(TaskHolder taskHolder)
    {
        TaskInfo finalTaskInfo = taskHolder.getFinalTaskInfo();
        if (finalTaskInfo != null) {
            return finalTaskInfo.getStats();
        }
        SqlTaskExecution taskExecution = taskHolder.getTaskExecution();
        if (taskExecution != null) {
            return taskExecution.getTaskContext().getTaskStats();
        }
        // if the task completed without creation, set end time
        DateTime endTime = taskStateMachine.getState().isDone() ? DateTime.now() : null;
        return new TaskStats(taskStateMachine.getCreatedTime(), endTime);
    }

    private static Set<PlanNodeId> getNoMoreSplits(TaskHolder taskHolder)
    {
        TaskInfo finalTaskInfo = taskHolder.getFinalTaskInfo();
        if (finalTaskInfo != null) {
            return finalTaskInfo.getNoMoreSplits();
        }
        SqlTaskExecution taskExecution = taskHolder.getTaskExecution();
        if (taskExecution != null) {
            return taskExecution.getNoMoreSplits();
        }
        return ImmutableSet.of();
    }

    private TaskInfo createTaskInfo(TaskHolder taskHolder)
    {
        // Take a single stats snapshot so the status (driver counts, memory reservation)
        // agrees with the stats embedded in the same TaskInfo, and stats are computed once.
        TaskStats taskStats = getTaskStats(taskHolder);
        Set<PlanNodeId> noMoreSplits = getNoMoreSplits(taskHolder);

        TaskStatus taskStatus = createTaskStatus(taskStats);
        return new TaskInfo(
                taskStatus,
                lastHeartbeat.get(),
                outputBuffer.getInfo(),
                noMoreSplits,
                taskStats,
                needsPlan.get(),
                taskStatus.getState().isDone());
    }

    /**
     * Returns the task status after the state machine reports a state change from
     * {@code callersCurrentState}, or immediately if the caller already believes
     * the task is done.
     */
    public ListenableFuture<TaskStatus> getTaskStatus(TaskState callersCurrentState)
    {
        requireNonNull(callersCurrentState, "callersCurrentState is null");

        // only the status is needed; don't build a full TaskInfo just to unwrap it
        if (callersCurrentState.isDone()) {
            return immediateFuture(getTaskStatus());
        }

        ListenableFuture<TaskState> futureTaskState = taskStateMachine.getStateChange(callersCurrentState);
        return Futures.transform(futureTaskState, (TaskState input) -> getTaskStatus());
    }

    /**
     * Returns the task info after the state machine reports a state change from
     * {@code callersCurrentState}, or immediately if the caller already believes
     * the task is done.
     */
    public ListenableFuture<TaskInfo> getTaskInfo(TaskState callersCurrentState)
    {
        requireNonNull(callersCurrentState, "callersCurrentState is null");

        // If the caller's current state is already done, just return the current
        // state of this task as it will either be done or possibly still running
        // (due to a bug in the caller), since we can not transition from a done
        // state.
        if (callersCurrentState.isDone()) {
            return immediateFuture(getTaskInfo());
        }

        ListenableFuture<TaskState> futureTaskState = taskStateMachine.getStateChange(callersCurrentState);
        return Futures.transform(futureTaskState, (TaskState input) -> getTaskInfo());
    }

    /**
     * Applies an update from the coordinator: installs output buffers, creates the
     * execution on first call (requires {@code fragment}), and adds {@code sources}.
     * A RuntimeException fails the task and is not rethrown; an Error fails the task
     * and is rethrown.
     *
     * @return the final TaskInfo if the task had already finished, otherwise a fresh TaskInfo
     */
    public TaskInfo updateTask(Session session, Optional<PlanFragment> fragment, List<TaskSource> sources, OutputBuffers outputBuffers)
    {
        try {
            // The LazyOutputBuffer does not support write methods, so the actual
            // output buffer must be established before drivers are created (e.g.
            // a VALUES query).
            outputBuffer.setOutputBuffers(outputBuffers);

            // assure the task execution is only created once
            SqlTaskExecution taskExecution;
            synchronized (this) {
                // is task already complete?
                TaskHolder taskHolder = taskHolderReference.get();
                if (taskHolder.isFinished()) {
                    return taskHolder.getFinalTaskInfo();
                }
                taskExecution = taskHolder.getTaskExecution();
                if (taskExecution == null) {
                    checkState(fragment.isPresent(), "fragment must be present");
                    taskExecution = sqlTaskExecutionFactory.create(session, queryContext, taskStateMachine, outputBuffer, fragment.get(), sources);
                    taskHolderReference.compareAndSet(taskHolder, new TaskHolder(taskExecution));
                    needsPlan.set(false);
                }
            }

            if (taskExecution != null) {
                taskExecution.addSources(sources);
            }
        }
        catch (Error e) {
            failed(e);
            throw e;
        }
        catch (RuntimeException e) {
            failed(e);
        }

        return getTaskInfo();
    }

    /** Reads results for {@code bufferId} from the output buffer starting at {@code startingSequenceId}. */
    public ListenableFuture<BufferResult> getTaskResults(OutputBufferId bufferId, long startingSequenceId, DataSize maxSize)
    {
        requireNonNull(bufferId, "bufferId is null");
        requireNonNull(maxSize, "maxSize is null");
        checkArgument(maxSize.toBytes() > 0, "maxSize must be at least 1 byte");

        return outputBuffer.get(bufferId, startingSequenceId, maxSize);
    }

    /** Aborts the given output buffer and returns a fresh TaskInfo. */
    public TaskInfo abortTaskResults(OutputBufferId bufferId)
    {
        requireNonNull(bufferId, "bufferId is null");

        log.debug("Aborting task %s output %s", taskId, bufferId);
        outputBuffer.abort(bufferId);

        return getTaskInfo();
    }

    /** Marks the task as failed with {@code cause}. */
    public void failed(Throwable cause)
    {
        requireNonNull(cause, "cause is null");

        taskStateMachine.failed(cause);
    }

    public TaskInfo cancel()
    {
        taskStateMachine.cancel();
        return getTaskInfo();
    }

    public TaskInfo abort()
    {
        taskStateMachine.abort();
        return getTaskInfo();
    }

    @Override
    public String toString()
    {
        return taskId.toString();
    }

    /**
     * Immutable snapshot of task progress: empty (not yet planned), running
     * (execution present), or finished (final TaskInfo and I/O stats present).
     */
    private static final class TaskHolder
    {
        private final SqlTaskExecution taskExecution;
        private final TaskInfo finalTaskInfo;
        private final SqlTaskIoStats finalIoStats;

        private TaskHolder()
        {
            this.taskExecution = null;
            this.finalTaskInfo = null;
            this.finalIoStats = null;
        }

        private TaskHolder(SqlTaskExecution taskExecution)
        {
            this.taskExecution = requireNonNull(taskExecution, "taskExecution is null");
            this.finalTaskInfo = null;
            this.finalIoStats = null;
        }

        private TaskHolder(TaskInfo finalTaskInfo, SqlTaskIoStats finalIoStats)
        {
            this.taskExecution = null;
            this.finalTaskInfo = requireNonNull(finalTaskInfo, "finalTaskInfo is null");
            this.finalIoStats = requireNonNull(finalIoStats, "finalIoStats is null");
        }

        public boolean isFinished()
        {
            return finalTaskInfo != null;
        }

        @Nullable
        public SqlTaskExecution getTaskExecution()
        {
            return taskExecution;
        }

        @Nullable
        public TaskInfo getFinalTaskInfo()
        {
            return finalTaskInfo;
        }

        public SqlTaskIoStats getIoStats()
        {
            // if we are finished, return the final IoStats
            if (finalIoStats != null) {
                return finalIoStats;
            }
            // if we haven't started yet, return an empty IoStats
            if (taskExecution == null) {
                return new SqlTaskIoStats();
            }
            // get IoStats from the current task execution
            TaskContext taskContext = taskExecution.getTaskContext();
            return new SqlTaskIoStats(taskContext.getInputDataSize(), taskContext.getInputPositions(), taskContext.getOutputDataSize(), taskContext.getOutputPositions());
        }
    }

    public void addStateChangeListener(StateChangeListener<TaskState> stateChangeListener)
    {
        taskStateMachine.addStateChangeListener(stateChangeListener);
    }
}