package io.airlift.concurrent;
import com.google.common.util.concurrent.Uninterruptibles;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
/**
 * Tests for {@link BoundedExecutor}. Covered behavior:
 * <ul>
 *   <li>mutual exclusion when the bound is one;</li>
 *   <li>the concurrency bound for several bound values;</li>
 *   <li>rejection of new work after the underlying executor has refused a task.</li>
 * </ul>
 */
public class TestBoundedExecutor
{
    private ExecutorService executorService;

    @BeforeClass
    public void setUp()
            throws Exception
    {
        executorService = Executors.newCachedThreadPool();
    }

    @AfterClass(alwaysRun = true)
    public void tearDown()
            throws Exception
    {
        executorService.shutdownNow();
    }

    /**
     * With a bound of one, tasks must never overlap.
     *
     * <p>Each task does a deliberately non-atomic read-then-write on a shared counter.
     * Any overlap would lose an update, and the final count would fall short of the
     * number of submitted tasks.
     */
    @Test
    public void testCounter()
            throws Exception
    {
        int maxThreads = 1;
        BoundedExecutor boundedExecutor = new BoundedExecutor(executorService, maxThreads); // Enforce single thread
        int stageTasks = 100_000;
        int totalTasks = stageTasks * 2;
        AtomicInteger counter = new AtomicInteger();

        executeInStages(boundedExecutor, maxThreads, stageTasks, () -> {
            // Intentional distinct read and write calls: overlapping tasks would lose updates
            int initialCount = counter.get();
            counter.set(initialCount + 1);
        });

        Assert.assertEquals(counter.get(), totalTasks);
    }

    @Test
    public void testSingleThreadBound()
            throws Exception
    {
        testBound(1, 100_000);
    }

    @Test
    public void testDoubleThreadBound()
            throws Exception
    {
        testBound(2, 100_000);
    }

    @Test
    public void testTripleThreadBound()
            throws Exception
    {
        testBound(3, 100_000);
    }

    /**
     * Once the underlying executor rejects a task, the bounded executor must keep
     * rejecting new tasks, even after the underlying executor works again.
     */
    @Test
    public void testExecutorCorruptionDetection()
            throws Exception
    {
        AtomicBoolean reject = new AtomicBoolean();
        Executor executor = command -> {
            if (reject.get()) {
                throw new RejectedExecutionException();
            }
            executorService.execute(command);
        };
        BoundedExecutor boundedExecutor = new BoundedExecutor(executor, 1); // Enforce single thread

        // Force the underlying executor to fail
        reject.set(true);
        try {
            boundedExecutor.execute(() -> fail("Should not be run"));
            fail("Execute should fail");
        }
        catch (Exception ignored) {
            // Expected: the underlying executor rejected the submission.
            // fail() throws AssertionError, which is not an Exception, so it is not swallowed here.
        }

        // Recover the underlying executor, but all new tasks should fail
        reject.set(false);
        try {
            boundedExecutor.execute(() -> fail("Should not be run"));
            fail("Execute should still fail");
        }
        catch (Exception ignored) {
            // Expected: the bounded executor stays failed after the earlier rejection.
            // The exact exception type is not pinned by this test.
        }
    }

    /**
     * Runs {@code 2 * stageTasks} tasks through a bounded executor and checks that
     * the number of concurrently active tasks never exceeds {@code maxThreads}.
     */
    private void testBound(final int maxThreads, int stageTasks)
    {
        BoundedExecutor boundedExecutor = new BoundedExecutor(executorService, maxThreads);
        AtomicInteger activeThreadCount = new AtomicInteger();
        AtomicBoolean failed = new AtomicBoolean();

        executeInStages(boundedExecutor, maxThreads, stageTasks, () -> {
            // The counter is only decremented after this task's own increment,
            // so the value observed here is always at least 1.
            int count = activeThreadCount.incrementAndGet();
            if (count > maxThreads) {
                failed.set(true);
            }
            activeThreadCount.decrementAndGet();
        });

        Assert.assertFalse(failed.get(), "More than " + maxThreads + " tasks ran concurrently");
    }

    /**
     * Submits {@code taskBody} {@code 2 * stageTasks} times in two stages, then waits
     * up to one minute for all submitted tasks to finish.
     *
     * <p>Stage 1: {@code stageTasks} tasks are queued first. Each one blocks until a
     * go signal, and the go signal is only given after {@code maxThreads} of them are
     * running. So the executor must reach its bound while the rest pile up behind it.
     *
     * <p>Stage 2: another {@code stageTasks} tasks are submitted while stage 1 drains,
     * so submission and draining run concurrently.
     *
     * @param boundedExecutor executor under test
     * @param maxThreads bound the executor was created with; also the number of stage-1 tasks that must start before the go signal
     * @param stageTasks number of tasks submitted in each stage
     * @param taskBody work performed by every task
     */
    private static void executeInStages(BoundedExecutor boundedExecutor, int maxThreads, int stageTasks, Runnable taskBody)
    {
        CountDownLatch initializeLatch = new CountDownLatch(maxThreads);
        CountDownLatch startLatch = new CountDownLatch(1);
        CountDownLatch completeLatch = new CountDownLatch(stageTasks * 2);

        // Pre-loaded tasks
        for (int i = 0; i < stageTasks; i++) {
            boundedExecutor.execute(() -> {
                try {
                    initializeLatch.countDown();
                    Uninterruptibles.awaitUninterruptibly(startLatch); // Wait for the go signal
                    taskBody.run();
                }
                finally {
                    completeLatch.countDown();
                }
            });
        }

        try {
            assertTrue(
                    Uninterruptibles.awaitUninterruptibly(initializeLatch, 1, TimeUnit.MINUTES),
                    "Pre-loaded tasks did not start"); // Wait for pre-load tasks to initialize
        }
        finally {
            // Always send the go signal. The pre-loaded tasks wait uninterruptibly, so if the
            // assertion above fails they would otherwise hold pool threads that
            // shutdownNow() cannot interrupt.
            startLatch.countDown();
        }

        // Concurrently submitted tasks
        for (int i = 0; i < stageTasks; i++) {
            boundedExecutor.execute(() -> {
                try {
                    taskBody.run();
                }
                finally {
                    completeLatch.countDown();
                }
            });
        }

        assertTrue(
                Uninterruptibles.awaitUninterruptibly(completeLatch, 1, TimeUnit.MINUTES),
                "Tasks did not complete in time"); // Wait for tasks to complete
    }
}