/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.airlift.concurrent; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import javax.annotation.concurrent.ThreadSafe; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; /** * Guarantees that no more than maxPermits of tasks will be run concurrently. * The class will rely on the ListenableFuture returned by the submitter function to determine * when a task has been completed. The submitter function NEEDS to be thread-safe and is recommended * to do the bulk of its work asynchronously. */ @ThreadSafe public class AsyncSemaphore<T> { private final Queue<QueuedTask<T>> queuedTasks = new ConcurrentLinkedQueue<>(); private final AtomicInteger counter = new AtomicInteger(); private final Runnable runNextTask = this::runNext; private final int maxPermits; private final Executor submitExecutor; private final Function<T, ListenableFuture<?>> submitter; public AsyncSemaphore(int maxPermits, Executor submitExecutor, Function<T, ListenableFuture<?>> submitter) { checkArgument(maxPermits > 0, "must have at least one permit"); this.maxPermits = maxPermits; this.submitExecutor = requireNonNull(submitExecutor, "submitExecutor is null"); this.submitter = requireNonNull(submitter, "submitter is null"); } public ListenableFuture<?> submit(T task) { QueuedTask<T> queuedTask = new QueuedTask<>(task); queuedTasks.add(queuedTask); acquirePermit(); return queuedTask.getCompletionFuture(); } private void acquirePermit() { if (counter.incrementAndGet() <= maxPermits) { // Kick off a task if not all permits have been handed out submitExecutor.execute(runNextTask); } } private void releasePermit() { if (counter.getAndDecrement() > maxPermits) { // Now that a task has finished, we can kick off another task if there are more tasks than permits submitExecutor.execute(runNextTask); } } private void runNext() { final QueuedTask<T> queuedTask = queuedTasks.poll(); ListenableFuture<?> future = submitTask(queuedTask.getTask()); Futures.addCallback(future, new FutureCallback<Object>() { @Override public void onSuccess(Object result) { queuedTask.markCompleted(); releasePermit(); } @Override public void onFailure(Throwable t) { queuedTask.markFailure(t); releasePermit(); } }); } private ListenableFuture<?> submitTask(T task) { try { ListenableFuture<?> future = submitter.apply(task); if (future == null) { return Futures.immediateFailedFuture(new NullPointerException("Submitter returned a null future for task: " + task)); } return future; } catch (Exception e) { return Futures.immediateFailedFuture(e); } } private static class QueuedTask<T> { private final T task; private final SettableFuture<?> settableFuture = SettableFuture.create(); private QueuedTask(T task) { this.task = requireNonNull(task, "task is null"); } public T getTask() { return task; } public void markFailure(Throwable throwable) { settableFuture.setException(throwable); } public void markCompleted() { settableFuture.set(null); } public ListenableFuture<?> getCompletionFuture() { return settableFuture; } } }