/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.airlift.concurrent; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.Uninterruptibles; import org.testng.Assert; import org.testng.annotations.Test; import javax.annotation.Nullable; import java.util.ArrayList; import java.util.List; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import static com.google.common.util.concurrent.Futures.addCallback; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertLessThanOrEqual; import static java.util.concurrent.Executors.newCachedThreadPool; public class TestAsyncSemaphore { private final ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("async-semaphore-%s"))); @Test public void testInlineExecution() throws Exception { AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(1, executor, task -> newDirectExecutorService().submit(task)); AtomicInteger count = new AtomicInteger(); List<ListenableFuture<?>> futures = new ArrayList<>(); for (int i = 0; i < 1000; i++) { futures.add(asyncSemaphore.submit(count::incrementAndGet)); } // Wait for completion Futures.allAsList(futures).get(1, TimeUnit.MINUTES); Assert.assertEquals(count.get(), 1000); } @Test public void testSingleThreadBoundedConcurrency() throws Exception { AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(1, executor, executor::submit); AtomicInteger count = new AtomicInteger(); AtomicInteger concurrency = new AtomicInteger(); List<ListenableFuture<?>> futures = new ArrayList<>(); for (int i = 0; i < 1000; i++) { futures.add(asyncSemaphore.submit((Runnable) () -> { count.incrementAndGet(); int currentConcurrency = concurrency.incrementAndGet(); assertLessThanOrEqual(currentConcurrency, 1); Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS); concurrency.decrementAndGet(); })); } // Wait for completion Futures.allAsList(futures).get(1, TimeUnit.MINUTES); Assert.assertEquals(count.get(), 1000); } @Test public void testMultiThreadBoundedConcurrency() throws Exception { AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(2, executor, executor::submit); AtomicInteger count = new AtomicInteger(); AtomicInteger concurrency = new AtomicInteger(); List<ListenableFuture<?>> futures = new ArrayList<>(); for (int i = 0; i < 1000; i++) { futures.add(asyncSemaphore.submit(() -> { count.incrementAndGet(); int currentConcurrency = concurrency.incrementAndGet(); assertLessThanOrEqual(currentConcurrency, 2); Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS); concurrency.decrementAndGet(); })); } // Wait for completion Futures.allAsList(futures).get(1, TimeUnit.MINUTES); Assert.assertEquals(count.get(), 1000); } @Test public void testMultiSubmitters() throws Exception { AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(2, executor, executor::submit); AtomicInteger count = new AtomicInteger(); AtomicInteger concurrency = new AtomicInteger(); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch completionLatch = new CountDownLatch(100); for (int i = 0; i < 100; i++) { executor.execute(() -> { Uninterruptibles.awaitUninterruptibly(startLatch, 1, TimeUnit.MINUTES); asyncSemaphore.submit((Runnable) () -> { count.incrementAndGet(); int currentConcurrency = concurrency.incrementAndGet(); assertLessThanOrEqual(currentConcurrency, 2); Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS); concurrency.decrementAndGet(); completionLatch.countDown(); }); }); } // Start the submitters; startLatch.countDown(); // Wait for completion Uninterruptibles.awaitUninterruptibly(completionLatch, 1, TimeUnit.MINUTES); Assert.assertEquals(count.get(), 100); } @Test public void testFailedTasks() throws Exception { AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(2, executor, executor::submit); AtomicInteger successCount = new AtomicInteger(); AtomicInteger failureCount = new AtomicInteger(); AtomicInteger concurrency = new AtomicInteger(); CountDownLatch completionLatch = new CountDownLatch(1000); List<ListenableFuture<?>> futures = new ArrayList<>(); for (int i = 0; i < 1000; i++) { ListenableFuture<?> future = asyncSemaphore.submit(() -> assertFailedConcurrency(concurrency)); addCallback(future, completionCallback(successCount, failureCount, completionLatch)); futures.add(future); } // Wait for all tasks and callbacks to complete completionLatch.await(1, TimeUnit.MINUTES); for (ListenableFuture<?> future : futures) { try { future.get(); Assert.fail(); } catch (Exception ignored) { } } Assert.assertEquals(successCount.get(), 0); Assert.assertEquals(failureCount.get(), 1000); } @Test public void testFailedTaskSubmission() throws Exception { AtomicInteger successCount = new AtomicInteger(); AtomicInteger failureCount = new AtomicInteger(); AtomicInteger concurrency = new AtomicInteger(); CountDownLatch completionLatch = new CountDownLatch(1000); AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(2, executor, task -> { throw assertFailedConcurrency(concurrency); }); List<ListenableFuture<?>> futures = new ArrayList<>(); for (int i = 0; i < 1000; i++) { // Should never execute this future ListenableFuture<?> future = asyncSemaphore.submit(Assert::fail); addCallback(future, completionCallback(successCount, failureCount, completionLatch)); futures.add(future); } // Wait for all tasks and callbacks to complete completionLatch.await(1, TimeUnit.MINUTES); for (ListenableFuture<?> future : futures) { try { future.get(); Assert.fail(); } catch (Exception ignored) { } } Assert.assertEquals(successCount.get(), 0); Assert.assertEquals(failureCount.get(), 1000); } @Test public void testFailedTaskWithMultipleSubmitters() throws Exception { AtomicInteger successCount = new AtomicInteger(); AtomicInteger failureCount = new AtomicInteger(); AtomicInteger concurrency = new AtomicInteger(); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch completionLatch = new CountDownLatch(100); AsyncSemaphore<Runnable> asyncSemaphore = new AsyncSemaphore<>(2, executor, task -> { throw assertFailedConcurrency(concurrency); }); Queue<ListenableFuture<?>> futures = new ConcurrentLinkedQueue<>(); for (int i = 0; i < 100; i++) { executor.execute(() -> { Uninterruptibles.awaitUninterruptibly(startLatch, 1, TimeUnit.MINUTES); // Should never execute this future ListenableFuture<?> future = asyncSemaphore.submit(Assert::fail); futures.add(future); addCallback(future, completionCallback(successCount, failureCount, completionLatch)); }); } // Start the submitters; startLatch.countDown(); // Wait for completion Uninterruptibles.awaitUninterruptibly(completionLatch, 1, TimeUnit.MINUTES); // Make sure they all report failure for (ListenableFuture<?> future : futures) { try { future.get(); Assert.fail(); } catch (Exception ignored) { } } Assert.assertEquals(successCount.get(), 0); Assert.assertEquals(failureCount.get(), 100); } @Test public void testNoStackOverflow() throws Exception { AsyncSemaphore<Object> asyncSemaphore = new AsyncSemaphore<>(1, executor, object -> Futures.immediateFuture(null)); List<ListenableFuture<?>> futures = new ArrayList<>(); for (int i = 0; i < 1000; i++) { futures.add(asyncSemaphore.submit(new Object())); } // Wait for completion Futures.allAsList(futures).get(1, TimeUnit.MINUTES); } private static RuntimeException assertFailedConcurrency(AtomicInteger concurrency) { int currentConcurrency = concurrency.incrementAndGet(); assertLessThanOrEqual(currentConcurrency, 2); Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS); concurrency.decrementAndGet(); throw new IllegalStateException(); } private static FutureCallback<Object> completionCallback(AtomicInteger successCount, AtomicInteger failureCount, CountDownLatch completionLatch) { return new FutureCallback<Object>() { @Override public void onSuccess(@Nullable Object result) { successCount.incrementAndGet(); completionLatch.countDown(); } @Override public void onFailure(Throwable t) { failureCount.incrementAndGet(); completionLatch.countDown(); } }; } }