package com.googlecode.totallylazy;

import com.googlecode.totallylazy.functions.Function0;
import com.googlecode.totallylazy.functions.Function1;

import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * A latch whose count can be moved both up and down.
 *
 * <p>Threads calling {@link #await()} block until the count is exactly zero. The count is
 * stored as the state of an {@link AbstractQueuedSynchronizer}; waiters are signalled whenever
 * an adjustment leaves the count at zero.
 *
 * <p>NOTE(review): neither the constructor nor {@link #countDown()} rejects negative values, so
 * the count may be negative; waiters are then only released once it returns to zero.
 */
public class CountLatch {
    private final Sync sync;

    /**
     * Creates a latch with the given initial count.
     *
     * @param count the initial count; waiters are released only while the count is zero
     */
    public CountLatch(int count) {
        this.sync = new Sync(count);
    }

    /** Creates a latch with an initial count of zero. */
    public CountLatch() {
        this(0);
    }

    /**
     * Blocks the calling thread until the count is zero.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void await() throws InterruptedException {
        // The argument is ignored by Sync.tryAcquireShared.
        sync.acquireSharedInterruptibly(-1);
    }

    /**
     * Blocks the calling thread until the count is zero or the timeout elapses.
     *
     * @param timeout the maximum time to wait
     * @param unit    the unit of {@code timeout}
     * @return {@code true} if the count was observed at zero, {@code false} if the wait timed out
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        // The first argument is ignored by Sync.tryAcquireShared.
        return sync.tryAcquireSharedNanos(-1, unit.toNanos(timeout));
    }

    /** Increments the count by one. */
    public void countUp() {
        sync.releaseShared(1);
    }

    /** Decrements the count by one, signalling waiters if the resulting count is zero. */
    public void countDown() {
        sync.releaseShared(-1);
    }

    /**
     * Returns the current count.
     *
     * @return the current count
     */
    public int count() {
        return sync.count();
    }

    /**
     * Returns a description of this latch in the form {@code Latch(<count>)}.
     *
     * @return the string form of this latch
     */
    @Override
    public String toString() {
        return String.format("Latch(%d)", sync.count());
    }

    /**
     * Wraps {@code callable} so that this latch is counted up before each invocation and counted
     * down afterwards.
     *
     * @param callable the computation to wrap
     * @param <A>      the result type
     * @return a function that runs {@code callable} while holding a count on this latch
     */
    public <A> Function0<A> monitor(Callable<? extends A> callable) {
        return monitor(callable, this);
    }

    /**
     * Wraps {@code callable} so that {@code latch} is counted up before each invocation and
     * counted down afterwards, including when {@code callable} throws.
     *
     * @param callable the computation to wrap
     * @param latch    the latch to adjust around each invocation
     * @param <A>      the result type
     * @return a function that runs {@code callable} while holding a count on {@code latch}
     */
    public static <A> Function0<A> monitor(final Callable<? extends A> callable, final CountLatch latch) {
        return () -> {
            latch.countUp();
            try {
                return callable.call();
            } finally {
                // Always undo the countUp above so the latch cannot be left held on failure.
                latch.countDown();
            }
        };
    }

    /**
     * Wraps {@code callable} so that this latch is counted up before each invocation and counted
     * down afterwards.
     *
     * @param callable the function to wrap
     * @param <A>      the argument type
     * @param <B>      the result type
     * @return a function that runs {@code callable} while holding a count on this latch
     */
    public <A, B> Function1<A, B> monitor(Function1<? super A, ? extends B> callable) {
        return monitor(this, callable);
    }

    /**
     * Wraps {@code callable} so that {@code latch} is counted up before each invocation and
     * counted down afterwards, including when {@code callable} throws.
     *
     * @param latch    the latch to adjust around each invocation
     * @param callable the function to wrap
     * @param <A>      the argument type
     * @param <B>      the result type
     * @return a function that runs {@code callable} while holding a count on {@code latch}
     */
    public static <A, B> Function1<A, B> monitor(final CountLatch latch, final Function1<? super A, ? extends B> callable) {
        return a -> {
            latch.countUp();
            try {
                return callable.call(a);
            } finally {
                // Always undo the countUp above so the latch cannot be left held on failure.
                latch.countDown();
            }
        };
    }

    /** Synchronizer that keeps the latch count in the AQS state. */
    private static final class Sync extends AbstractQueuedSynchronizer {
        public Sync(int count) {
            setState(count);
        }

        /**
         * Succeeds (returns a positive value) only while the count is zero.
         *
         * @param ignore unused
         */
        @Override
        protected final int tryAcquireShared(int ignore) {
            return finished() ? 1 : -1;
        }

        /**
         * Atomically adds {@code adjust} to the count.
         *
         * @param adjust the signed amount to add to the count
         * @return {@code true} if this call left the count at zero, so waiters should be signalled
         */
        @Override
        protected final boolean tryReleaseShared(int adjust) {
            for (;;) {
                int oldValue = count();
                int newValue = oldValue + adjust;
                if (compareAndSetState(oldValue, newValue)) {
                    // Decide from the value this call installed; re-reading the state here could
                    // observe a later update made by another thread.
                    return newValue == 0;
                }
            }
        }

        /** Returns the current count held in the synchronizer state. */
        final int count() {
            return getState();
        }

        /** Returns whether the current count is zero. */
        final boolean finished() {
            return count() == 0;
        }
    }
}