package games.strategy.util;
import java.io.Serializable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
/**
 * A synchronization aid very similar to {@link CountDownLatch}, except that the count can also be incremented.
 *
 * <p>
 * Threads calling {@link #await()} block until the count reaches zero. Unlike {@link CountDownLatch}, the count may be
 * raised again afterwards (via {@link #increment()} or {@link #applyDelta(int)}). Raising it makes later callers of
 * {@code await} block again until the count returns to zero. The count never drops below zero: decrements past zero
 * are clamped to zero.
 * </p>
 *
 * <p>
 * The count is stored as the state of an {@link AbstractQueuedSynchronizer} and updated with compare-and-set,
 * mirroring the implementation of {@link CountDownLatch}.
 * </p>
 *
 * <p>
 * Cobbled together from various answers and discussions on Stack Exchange / Stack Overflow.
 * </p>
 */
public class CountUpAndDownLatch implements Serializable {
  private static final long serialVersionUID = -1656388212821764097L;
  private final Sync sync;
  private final int originalCount;

  /**
   * Constructs a {@link CountUpAndDownLatch} initialized with zero, i.e. initially open.
   */
  public CountUpAndDownLatch() {
    this(0);
  }

  /**
   * Constructs a {@link CountUpAndDownLatch} initialized with the given count.
   *
   * @param initialCount
   *        the number of times {@link #countDown} must be invoked before threads can pass through {@link #await}
   * @throws IllegalArgumentException
   *         if {@code initialCount} is negative
   */
  public CountUpAndDownLatch(final int initialCount) {
    if (initialCount < 0) {
      throw new IllegalArgumentException("count < 0");
    }
    sync = new Sync(initialCount);
    originalCount = initialCount;
  }

  /**
   * Increments the count by one.
   *
   * @throws IllegalStateException
   *         if the count is already {@link Integer#MAX_VALUE}
   */
  public void increment() {
    sync.releaseShared(1);
  }

  /**
   * Decrements the count of the latch, releasing all waiting threads if the count reaches zero.
   * Has no effect if the count is already zero.
   *
   * @see CountDownLatch#countDown()
   */
  public void countDown() {
    sync.releaseShared(-1);
  }

  /**
   * Atomically adds {@code delta} to the count.
   *
   * <p>
   * A negative delta counts down; the result is clamped at zero. If the count reaches zero, all waiting threads are
   * released. A delta of zero has no effect.
   * </p>
   *
   * @param delta
   *        the amount to increment (or, if negative, decrement) the count by
   * @throws IllegalStateException
   *         if a positive {@code delta} would push the count above {@link Integer#MAX_VALUE}; the count is left
   *         unchanged in that case
   */
  public void applyDelta(final int delta) {
    sync.releaseShared(delta);
  }

  /**
   * Counts down to zero in one step, releasing all waiting threads.
   */
  public void releaseAll() {
    applyDelta(Integer.MIN_VALUE);
  }

  /**
   * Resets the latch to its original count.
   *
   * <p>
   * If the original count is zero, this releases all waiting threads. Otherwise the count is set atomically to the
   * original value, so concurrent {@link #countDown()}/{@link #increment()} calls are ordered either entirely before or
   * entirely after the reset.
   * </p>
   */
  public void resetCount() {
    if (originalCount == 0) {
      releaseAll();
    } else {
      // A positive target never constitutes a transition to zero, so no waiters need to be signalled.
      sync.setCount(originalCount);
    }
  }

  /**
   * Returns the current count.
   *
   * @return the current count, never negative
   * @see CountDownLatch#getCount()
   */
  public int getCount() {
    return sync.getCount();
  }

  /**
   * Returns the count this latch was created with.
   *
   * @return the original count this latch was created with
   */
  public int getOriginalCount() {
    return originalCount;
  }

  /**
   * Causes the current thread to wait until the latch has counted down to zero, unless the thread is interrupted.
   *
   * @throws InterruptedException
   *         if the current thread is interrupted while waiting
   * @see CountDownLatch#await()
   */
  public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
  }

  /**
   * Causes the current thread to wait until the latch has counted down to zero, unless the thread is interrupted or
   * the specified waiting time elapses.
   *
   * @param timeout
   *        the maximum time to wait
   * @param unit
   *        the time unit of the {@code timeout} argument
   * @return {@code true} if the count reached zero and {@code false} if the waiting time elapsed first
   * @throws InterruptedException
   *         if the current thread is interrupted while waiting
   * @see CountDownLatch#await(long,TimeUnit)
   */
  public boolean await(final long timeout, final TimeUnit unit) throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
  }

  /**
   * Returns a string identifying this latch, as well as its state.
   * The state, in brackets, includes the String "Count =" followed by the current count.
   */
  @Override
  public String toString() {
    return super.toString() + "[Count = " + sync.getCount() + "]";
  }

  /**
   * Synchronization control for {@link CountUpAndDownLatch}.
   * Uses the AQS state to represent the (non-negative) count.
   */
  private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = -7639904478060101736L;

    private Sync(final int initialState) {
      setState(initialState);
    }

    int getCount() {
      return getState();
    }

    /** Atomically overwrites the count; callers must ensure {@code count} is non-negative. */
    void setCount(final int count) {
      setState(count);
    }

    @Override
    protected int tryAcquireShared(final int acquires) {
      // Acquisition (i.e. passing await) succeeds only while the count is zero.
      return getState() == 0 ? 1 : -1;
    }

    /**
     * Applies {@code delta} to the count.
     *
     * @return {@code true} only when this call moved the count from positive to zero, which tells AQS to wake waiters
     * @throws IllegalStateException
     *         if the result would exceed {@link Integer#MAX_VALUE}
     */
    @Override
    protected boolean tryReleaseShared(final int delta) {
      if (delta == 0) {
        return false;
      }
      for (;;) {
        final int current = getState();
        // Widen to long: in int arithmetic a large positive delta would wrap to a negative value, get clamped to
        // zero below, and wrongly release all waiters.
        final long sum = (long) current + delta;
        if (current <= 0 && sum <= 0) {
          // Already at zero and not being raised: nothing to change and nothing to signal.
          return false;
        }
        if (sum > Integer.MAX_VALUE) {
          throw new IllegalStateException("Maximum count exceeded: " + current + " + " + delta);
        }
        // Counting down past zero is clamped at zero.
        final int next = (int) Math.max(0L, sum);
        if (compareAndSetState(current, next)) {
          return next == 0;
        }
      }
    }
  }
}