/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.execution; import com.facebook.presto.spi.Node; import com.facebook.presto.util.FinalizerService; import com.google.common.collect.Sets; import io.airlift.log.Logger; import javax.annotation.concurrent.ThreadSafe; import javax.inject.Inject; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.IntConsumer; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; @ThreadSafe public class NodeTaskMap { private static final Logger log = Logger.get(NodeTaskMap.class); private final ConcurrentHashMap<Node, NodeTasks> nodeTasksMap = new ConcurrentHashMap<>(); private final FinalizerService finalizerService; @Inject public NodeTaskMap(FinalizerService finalizerService) { this.finalizerService = requireNonNull(finalizerService, "finalizerService is null"); } public void addTask(Node node, RemoteTask task) { createOrGetNodeTasks(node).addTask(task); } public int getPartitionedSplitsOnNode(Node node) { return createOrGetNodeTasks(node).getPartitionedSplitCount(); } public PartitionedSplitCountTracker createPartitionedSplitCountTracker(Node node, TaskId taskId) { return createOrGetNodeTasks(node).createPartitionedSplitCountTracker(taskId); } private NodeTasks createOrGetNodeTasks(Node node) { NodeTasks nodeTasks = nodeTasksMap.get(node); if (nodeTasks == null) { nodeTasks = addNodeTask(node); } return nodeTasks; } private NodeTasks addNodeTask(Node node) { NodeTasks newNodeTasks = new NodeTasks(finalizerService); NodeTasks nodeTasks = nodeTasksMap.putIfAbsent(node, newNodeTasks); if (nodeTasks == null) { return newNodeTasks; } return nodeTasks; } private static class NodeTasks { private final Set<RemoteTask> remoteTasks = Sets.newConcurrentHashSet(); private final AtomicInteger nodeTotalPartitionedSplitCount = new AtomicInteger(); private final FinalizerService finalizerService; public NodeTasks(FinalizerService finalizerService) { this.finalizerService = requireNonNull(finalizerService, "finalizerService is null"); } private int getPartitionedSplitCount() { return nodeTotalPartitionedSplitCount.get(); } private void addTask(RemoteTask task) { if (remoteTasks.add(task)) { task.addStateChangeListener(taskStatus -> { if (taskStatus.getState().isDone()) { remoteTasks.remove(task); } }); // Check if task state is already done before adding the listener if (task.getTaskStatus().getState().isDone()) { remoteTasks.remove(task); } } } public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId taskId) { requireNonNull(taskId, "taskId is null"); TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId); PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker::setPartitionedSplitCount); // when partitionedSplitCountTracker is garbage collected, run the cleanup method on the tracker // Note: tracker can not have a reference to partitionedSplitCountTracker finalizerService.addFinalizer(partitionedSplitCountTracker, tracker::cleanup); return partitionedSplitCountTracker; } @ThreadSafe private class TaskPartitionedSplitCountTracker { private final TaskId taskId; private final AtomicInteger localPartitionedSplitCount = new AtomicInteger(); public TaskPartitionedSplitCountTracker(TaskId taskId) { this.taskId = requireNonNull(taskId, "taskId is null"); } public synchronized void setPartitionedSplitCount(int partitionedSplitCount) { if (partitionedSplitCount < 0) { int oldValue = localPartitionedSplitCount.getAndSet(0); nodeTotalPartitionedSplitCount.addAndGet(-oldValue); throw new IllegalArgumentException("partitionedSplitCount is negative"); } int oldValue = localPartitionedSplitCount.getAndSet(partitionedSplitCount); nodeTotalPartitionedSplitCount.addAndGet(partitionedSplitCount - oldValue); } public void cleanup() { int leakedSplits = localPartitionedSplitCount.getAndSet(0); if (leakedSplits == 0) { return; } log.error("BUG! %s for %s leaked with %s partitioned splits. Cleaning up so server can continue to function.", getClass().getName(), taskId, leakedSplits); nodeTotalPartitionedSplitCount.addAndGet(-leakedSplits); } @Override public String toString() { return toStringHelper(this) .add("taskId", taskId) .add("splits", localPartitionedSplitCount) .toString(); } } } public static class PartitionedSplitCountTracker { private final IntConsumer splitSetter; public PartitionedSplitCountTracker(IntConsumer splitSetter) { this.splitSetter = requireNonNull(splitSetter, "splitSetter is null"); } public void setPartitionedSplitCount(int partitionedSplitCount) { splitSetter.accept(partitionedSplitCount); } @Override public String toString() { return splitSetter.toString(); } } }