/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.concurrent; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import javax.annotation.concurrent.GuardedBy; import java.util.Collection; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.FutureTask; import java.util.concurrent.PriorityBlockingQueue; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; /** * <p>Executes batches of tasks such that individual tasks within each batch are interleaved with tasks from other batches.</p> * <p/> * <p>E.g., if two batches containing elements ["a1", "a2", "a3", "a4"] and ["b1", "b2", "b3", "b4", "b5"] are submitted * in that order, a possible execution might be:</p> * <p/> * <p>"a1", "b1", "a2", "b2", "a3", "b3", "a4", "b4", "b5"</p> * <p/> * <p>If a third batch ["c1", "c2", "c3"] were submitted when the execution in the example above is at "a3", a possible * execution might be:</p> * <p/> * <p>"a1", "b1", "a2", "b2", "a3", "b3", "c1", "a4", "b4", "c2", "b5", "c3"</p> * <p/> * <p>The first task of a batch will not execute before the first task of a previously submitted batch, therefore * guaranteeing that no batch will get starved.</p> */ 
public class FairBatchExecutor { private static final Logger log = Logger.get(FairBatchExecutor.class); private final AtomicBoolean shutdown = new AtomicBoolean(); private final int threads; private final ExecutorService executor; private final PriorityBlockingQueue<PrioritizedFutureTask> queue = new PriorityBlockingQueue<>(); @GuardedBy("this") private long basePriority; public FairBatchExecutor(int threads, ThreadFactory threadFactory) { this.threads = threads; this.executor = new ThreadPoolExecutor(threads, threads, 1, TimeUnit.MINUTES, new SynchronousQueue<Runnable>(), threadFactory, new ThreadPoolExecutor.DiscardPolicy()); } public void shutdown() { shutdown.set(true); executor.shutdown(); // poison pills for (int i = 0; i < threads; i++) { queue.add(new PrioritizedFutureTask<>(-1, new Callable<Void>() { @Override public Void call() throws Exception { return null; } })); } } // TODO: add shutdownNow public <T> List<FutureTask<T>> processBatch(Collection<? extends Callable<T>> tasks) { Preconditions.checkState(!shutdown.get(), "Executor is already shut down"); long priority = computeStartingPriority(); ImmutableList.Builder<FutureTask<T>> result = ImmutableList.builder(); for (Callable<T> task : tasks) { PrioritizedFutureTask<T> future = new PrioritizedFutureTask<>(priority++, task); queue.add(future); result.add(future); } // Make sure we have enough processors to achieve the desired concurrency level for (int i = 0; i < Math.min(threads, tasks.size()); ++i) { executor.execute(new Runnable() { @Override public void run() { trigger(); } }); } return result.build(); } private long computeStartingPriority() { synchronized (this) { // increment the base priority so that the first pending task // of previously submitted batches takes precedence basePriority++; return basePriority; } } private void updateStartingPriority(long newBase) { synchronized (this) { // update the base priority so that newly submitted batches are // interleaved correctly with tasks at the 
front of the queue if (basePriority < newBase) { basePriority = newBase; } } } private void trigger() { boolean interrupted = false; try { while (!Thread.currentThread().isInterrupted() && !shutdown.get()) { PrioritizedFutureTask<?> task = queue.take(); try { task.run(); } finally { updateStartingPriority(task.priority); } } } catch (InterruptedException e) { interrupted = true; } finally { if (!shutdown.get()) { // attempt to submit a new task in case we died due to unexpected reasons executor.execute(new Runnable() { @Override public void run() { trigger(); } }); } } if (interrupted) { Thread.currentThread().interrupt(); } } private static class PrioritizedFutureTask<T> extends FutureTask<T> implements Comparable<PrioritizedFutureTask> { private final long priority; private PrioritizedFutureTask(long priority, Callable<T> callable) { super(callable); this.priority = priority; } @Override public int compareTo(PrioritizedFutureTask o) { return Long.compare(priority, o.priority); } } }