package com.linkedin.d2.balancer.util; import com.linkedin.common.callback.Function; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import static org.testng.Assert.*; /** * Miscellaneous code that's useful for testing. * Any of this code may throw AssertionError when something unexpected happens. */ public class TestHelper { /** * Assert that actual and expected contain the same objects in the same order. */ public static <T> void assertSameElements(Iterable<T> actual, Iterable<T> expected) { Iterator<T> e = expected.iterator(); int index = 0; for (T a : actual) { assertTrue(e.hasNext(), "too long: " + actual + ".size > " + expected + ".size"); assertSame(a, e.next(), "not same: actual[" + index + "]"); ++index; } assertFalse(e.hasNext(), "too short: " + actual + ".size < " + expected + ".size"); } /** * Partition from into subLists, each of size subListSize. */ public static <T> List<List<T>> split(List<T> from, int subListSize) { List<List<T>> into = new ArrayList<List<T>>(); for (int first = 0; first < from.size(); first += subListSize) { into.add(from.subList(first, Math.min(first + subListSize, from.size()))); } return into; } public static <T> List<T> getAll(Collection<Future<T>> futures) { return getAll(futures, futures.size() * 10, TimeUnit.SECONDS); // plenty of time } public static <T> List<T> getAll(Iterable<Future<T>> futures, long timeout, TimeUnit unit) { List<T> all = new ArrayList<T>(); final long deadline = System.nanoTime() + unit.toNanos(timeout); int f = 0; for (Future<T> future : futures) { try { all.add(future.get(deadline - System.nanoTime(), TimeUnit.NANOSECONDS)); } catch (Exception e) { fail("index " + f, e); } ++f; } return all; } public static <T> List<Future<T>> concurrently(Collection<Callable<T>> calls) { final int numberOfCalls = calls.size(); CountDownLatch ready = new CountDownLatch(numberOfCalls); CountDownLatch start = new CountDownLatch(1); List<Future<T>> futures = new ArrayList<Future<T>>(numberOfCalls); { ExecutorService pool = newFixedDaemonPool(numberOfCalls); for (Callable<T> call : calls) futures.add(pool.submit(new PauseCallable<T>(1, ready, start, call))); assertEquals(futures.size(), numberOfCalls); } try { assertTrue(ready.await(numberOfCalls * 10, TimeUnit.SECONDS)); } catch (InterruptedException e) { fail(e + "", e); } start.countDown(); // allow all threads to proceed return futures; } public static ExecutorService newFixedDaemonPool(int numberOfThreads) { return Executors.newFixedThreadPool(numberOfThreads, new DaemonFactory()); } public static class DaemonFactory implements ThreadFactory { private static final AtomicLong factoryNumbers = new AtomicLong(0); private final long factoryNumber = factoryNumbers.incrementAndGet(); private final AtomicLong threadNumbers = new AtomicLong(0); @Override public Thread newThread(Runnable target) { Thread thread = new Thread(target); thread.setDaemon(true); // Structured thread names are helpful for debugging. thread.setName("daemon-" + factoryNumber + "." + threadNumbers.incrementAndGet()); return thread; } } /** * A Callable that pauses execution of its calling threads. */ public static class PauseCallable<T> implements Callable<T> { private final long _pauseCall; private final CountDownLatch _paused; private final CountDownLatch _resume; private final Callable<T> _target; private final AtomicLong _calls = new AtomicLong(0); public PauseCallable(long pauseCall, CountDownLatch paused, CountDownLatch resume, Callable<T> target) { _pauseCall = pauseCall; _paused = paused; _resume = resume; _target = target; } /** The number of times this object has been called. */ public long getCalls() { return _calls.get(); } @Override public T call() throws Exception { if (_calls.incrementAndGet() >= _pauseCall) { _paused.countDown(); _resume.await(); } return _target.call(); } } }