package games.strategy.thread;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;

import com.google.common.annotations.VisibleForTesting;
/**
 * Utility class for ensuring that locks are acquired in a consistent order.
 *
 * <p>
 * Simply use this class and call {@code acquireLock(aLock)} / {@code releaseLock(aLock)} instead of
 * {@code aLock.lock()} / {@code aLock.unlock()}. If locks are acquired in an inconsistent order, an error is reported
 * (by default, a message including the current stack trace is printed to standard error).
 * </p>
 *
 * <p>
 * This class is not terribly good for multithreading, as every acquisition of a lock the calling thread does not
 * already hold synchronizes on a single global mutex. That is acceptable, because this code is meant for when you are
 * considering your ambitious multi-threaded code a mistake and are trying to limit the damage.
 * </p>
 */
public enum LockUtil {
  INSTANCE;

  // The locks the current thread holds. Because locks can be re-entrant, each maps to a hold count.
  private final ThreadLocal<Map<Lock, Integer>> locksHeld = ThreadLocal.withInitial(() -> new HashMap<>());
  // For every lock ever acquired through this class, the locks that were held when it was acquired.
  // Keys (WeakHashMap) and values (WeakLockRef) are weak so that locks don't linger here forever.
  private final Map<Lock, Set<WeakLockRef>> locksHeldWhenAcquired = new WeakHashMap<>();
  // Guards locksHeldWhenAcquired, which is shared by all threads.
  private final Object mutex = new Object();
  private final AtomicReference<ErrorReporter> errorReporterRef = new AtomicReference<>(new DefaultErrorReporter());

  /**
   * Acquires {@code aLock}, first checking that it is not being taken in an order inconsistent with earlier
   * acquisitions.
   *
   * <p>
   * If the current thread already holds {@code aLock} (via this class), only its hold count is incremented. Otherwise,
   * every lock the thread currently holds is recorded as having been held when {@code aLock} was acquired, and an error
   * is reported for each currently held lock that was itself previously acquired while {@code aLock} was held. The check
   * happens before blocking on {@code aLock.lock()} so that an ordering problem is reported even if it then deadlocks.
   * </p>
   *
   * @param aLock the lock to acquire
   */
  public void acquireLock(final Lock aLock) {
    final Map<Lock, Integer> heldByThisThread = locksHeld.get();
    final Integer holdCount = heldByThisThread.get(aLock);
    if (holdCount != null) {
      // Re-entrant acquisition: the ordering was already checked when the lock was first taken.
      heldByThisThread.put(aLock, holdCount + 1);
    } else {
      final List<Lock> violations = recordAndCheckOrdering(aLock, heldByThisThread.keySet());
      // Report outside the global mutex so a slow reporter cannot stall every other thread.
      final ErrorReporter reporter = errorReporterRef.get();
      for (final Lock violation : violations) {
        reporter.reportError(aLock, violation);
      }
      heldByThisThread.put(aLock, 1);
    }
    try {
      aLock.lock();
    } catch (final RuntimeException | Error e) {
      // Keep the bookkeeping consistent with reality if the lock could not actually be taken.
      decrementHoldCount(heldByThisThread, aLock);
      throw e;
    }
  }

  /**
   * Records that every lock in {@code currentlyHeld} was held when {@code aLock} was acquired, and returns those
   * currently held locks that were themselves previously acquired while {@code aLock} was held (i.e. the locks with
   * which this acquisition is inconsistently ordered), in iteration order of {@code currentlyHeld}.
   */
  private List<Lock> recordAndCheckOrdering(final Lock aLock, final Set<Lock> currentlyHeld) {
    final List<Lock> violations = new ArrayList<>();
    synchronized (mutex) {
      final Set<WeakLockRef> heldWhenThisAcquired = locksHeldWhenAcquired.computeIfAbsent(aLock, k -> new HashSet<>());
      for (final Lock l : currentlyHeld) {
        heldWhenThisAcquired.add(new WeakLockRef(l));
      }
      // We are acquiring aLock while holding l; check whether aLock was ever held when l was acquired.
      final WeakLockRef target = new WeakLockRef(aLock);
      for (final Lock l : currentlyHeld) {
        final Set<WeakLockRef> heldWhenOtherAcquired = locksHeldWhenAcquired.get(l);
        if (heldWhenOtherAcquired == null) {
          continue;
        }
        // clear out references to locks that have been garbage collected
        heldWhenOtherAcquired.removeIf(ref -> ref.get() == null);
        if (heldWhenOtherAcquired.contains(target)) {
          violations.add(l);
        }
      }
    }
    return violations;
  }

  /**
   * Releases {@code aLock}, decrementing the current thread's hold count and forgetting the lock once the count
   * reaches zero.
   *
   * @param aLock the lock to release
   * @throws IllegalMonitorStateException if the current thread does not hold {@code aLock} via
   *         {@link #acquireLock(Lock)}
   */
  public void releaseLock(final Lock aLock) {
    final Map<Lock, Integer> heldByThisThread = locksHeld.get();
    if (!heldByThisThread.containsKey(aLock)) {
      throw new IllegalMonitorStateException(
          "Lock was not acquired through LockUtil by the current thread: " + aLock);
    }
    decrementHoldCount(heldByThisThread, aLock);
    aLock.unlock();
  }

  /** Decrements the hold count of {@code aLock}, removing the entry once it would drop to zero. */
  private static void decrementHoldCount(final Map<Lock, Integer> heldByThisThread, final Lock aLock) {
    heldByThisThread.computeIfPresent(aLock, (k, count) -> count > 1 ? count - 1 : null);
  }

  /**
   * Returns whether the current thread holds {@code aLock} via {@link #acquireLock(Lock)}.
   *
   * @param aLock the lock to query
   * @return {@code true} if the current thread has a non-zero hold count for {@code aLock}
   */
  public boolean isLockHeld(final Lock aLock) {
    return locksHeld.get().containsKey(aLock);
  }

  /**
   * Replaces the error reporter.
   *
   * @param errorReporter the new reporter
   * @return the previously installed reporter
   */
  @VisibleForTesting
  ErrorReporter setErrorReporter(final ErrorReporter errorReporter) {
    return errorReporterRef.getAndSet(errorReporter);
  }

  /** Receives notifications about inconsistent lock ordering. */
  @VisibleForTesting
  interface ErrorReporter {
    /**
     * Called when {@code from} is being acquired while holding {@code to}, although {@code from} was previously held
     * when {@code to} was acquired.
     */
    void reportError(Lock from, Lock to);
  }

  /** Prints ordering errors, with the current stack trace, to standard error. */
  private static final class DefaultErrorReporter implements ErrorReporter {
    @Override
    public void reportError(final Lock from, final Lock to) {
      System.err.println("Invalid lock ordering at, from:" + from + " to:" + to + " stack trace:" + getStackTrace());
    }

    private static String getStackTrace() {
      final StackTraceElement[] trace = Thread.currentThread().getStackTrace();
      final StringBuilder builder = new StringBuilder();
      for (final StackTraceElement e : trace) {
        builder.append(e.toString());
        builder.append("\n");
      }
      return builder.toString();
    }
  }

  /** A weak reference to a lock whose equality is based on referent identity. */
  private static final class WeakLockRef extends WeakReference<Lock> {
    // cache the hash code to make sure it doesn't change if our reference
    // has been cleared
    private final int hashCode;

    public WeakLockRef(final Lock referent) {
      super(referent);
      hashCode = referent.hashCode();
    }

    @Override
    public boolean equals(final Object o) {
      if (o == this) {
        return true;
      }
      if (o instanceof WeakLockRef) {
        final WeakLockRef other = (WeakLockRef) o;
        return other.get() == this.get();
      }
      return false;
    }

    @Override
    public int hashCode() {
      return hashCode;
    }
  }
}