package com.google.inject.internal; import com.google.common.collect.ImmutableList; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimaps; import com.google.inject.internal.CycleDetectingLock.CycleDetectingLockFactory; import com.google.inject.internal.CycleDetectingLock.CycleDetectingLockFactory.ReentrantCycleDetectingLock; import java.util.Collection; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.ReentrantLock; import junit.framework.TestCase; public class CycleDetectingLockTest extends TestCase { static final long DEADLOCK_TIMEOUT_SECONDS = 1; /** * Verifies that graph of threads' dependencies is not static and is calculated in runtime using * information about specific locks. * * <pre> * T1: Waits on S1 * T2: Locks B, sends S1, waits on S2 * T1: Locks A, start locking B, sends S2, waits on S3 * T2: Unlocks B, start locking A, sends S3, finishes locking A, unlocks A * T1: Finishes locking B, unlocks B, unlocks A * </pre> * * <p>This should succeed, even though T1 was locked on T2 and T2 is locked on T1 when T2 locks A. * Incorrect implementation detects a cycle waiting on S3. */ public void testSingletonThreadsRuntimeCircularDependency() throws Exception { final CyclicBarrier signal1 = new CyclicBarrier(2); final CyclicBarrier signal2 = new CyclicBarrier(2); final CyclicBarrier signal3 = new CyclicBarrier(2); final CycleDetectingLockFactory<String> lockFactory = new CycleDetectingLockFactory<String>(); final CycleDetectingLock<String> lockA = new ReentrantCycleDetectingLock<String>( lockFactory, "A", new ReentrantLock() { @Override public void lock() { if (Thread.currentThread().getName().equals("T2")) { try { signal3.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); } catch (Exception e) { throw new RuntimeException(e); } } else { assertEquals("T1", Thread.currentThread().getName()); } super.lock(); } }); final CycleDetectingLock<String> lockB = new ReentrantCycleDetectingLock<String>( lockFactory, "B", new ReentrantLock() { @Override public void lock() { if (Thread.currentThread().getName().equals("T1")) { try { signal2.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); signal3.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); } catch (Exception e) { throw new RuntimeException(e); } } else { assertEquals("T2", Thread.currentThread().getName()); } super.lock(); } }); Future<Void> firstThreadResult = Executors.newSingleThreadExecutor() .submit( new Callable<Void>() { @Override public Void call() throws Exception { Thread.currentThread().setName("T1"); signal1.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty()); assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty()); lockB.unlock(); lockA.unlock(); return null; } }); Future<Void> secondThreadResult = Executors.newSingleThreadExecutor() .submit( new Callable<Void>() { @Override public Void call() throws Exception { Thread.currentThread().setName("T2"); assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty()); signal1.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); signal2.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); lockB.unlock(); assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty()); lockA.unlock(); return null; } }); firstThreadResult.get(DEADLOCK_TIMEOUT_SECONDS * 3, TimeUnit.SECONDS); secondThreadResult.get(DEADLOCK_TIMEOUT_SECONDS * 3, TimeUnit.SECONDS); } /** * Verifies that factories do not deadlock each other. * * <pre> * Thread A: lock a lock A (factory A) * Thread B: lock a lock B (factory B) * Thread A: lock a lock B (factory B) * Thread B: lock a lock A (factory A) * </pre> * * <p>This should succeed even though from the point of view of each individual factory there are * no deadlocks to detect. */ public void testCycleDetectingLockFactoriesDoNotDeadlock() throws Exception { final CycleDetectingLockFactory<String> factoryA = new CycleDetectingLockFactory<String>(); final CycleDetectingLock<String> lockA = factoryA.create("A"); final CycleDetectingLockFactory<String> factoryB = new CycleDetectingLockFactory<String>(); final CycleDetectingLock<String> lockB = factoryB.create("B"); final CyclicBarrier eachThreadAcquiredFirstLock = new CyclicBarrier(2); Future<Boolean> threadA = Executors.newSingleThreadExecutor() .submit( new Callable<Boolean>() { @Override public Boolean call() throws Exception { Thread.currentThread().setName("A"); assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty()); eachThreadAcquiredFirstLock.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); boolean isEmpty = lockB.lockOrDetectPotentialLocksCycle().isEmpty(); if (isEmpty) { lockB.unlock(); } lockA.unlock(); return isEmpty; } }); Future<Boolean> threadB = Executors.newSingleThreadExecutor() .submit( new Callable<Boolean>() { @Override public Boolean call() throws Exception { Thread.currentThread().setName("B"); assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty()); eachThreadAcquiredFirstLock.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS); boolean isEmpty = lockA.lockOrDetectPotentialLocksCycle().isEmpty(); if (isEmpty) { lockA.unlock(); } lockB.unlock(); return isEmpty; } }); boolean deadlockADetected = threadA.get(DEADLOCK_TIMEOUT_SECONDS * 2, TimeUnit.SECONDS); boolean deadlockBDetected = threadB.get(DEADLOCK_TIMEOUT_SECONDS * 2, TimeUnit.SECONDS); assertTrue("Deadlock should get detected", deadlockADetected || deadlockBDetected); assertTrue("One deadlock should get detected", deadlockADetected != deadlockBDetected); } /** * Verifies that factories deadlocks report the correct cycles. * * <pre> * Thread 1: takes locks a, b * Thread 2: takes locks b, c * Thread 3: takes locks c, a * </pre> * * <p>In order to ensure a deadlock, each thread will wait on a barrier right after grabbing the * first lock. */ public void testCycleReporting() throws Exception { final CycleDetectingLockFactory<String> factory = new CycleDetectingLockFactory<String>(); final CycleDetectingLock<String> lockA = factory.create("a"); final CycleDetectingLock<String> lockB = factory.create("b"); final CycleDetectingLock<String> lockC = factory.create("c"); final CyclicBarrier barrier = new CyclicBarrier(3); ImmutableList<Future<ListMultimap<Thread, String>>> futures = ImmutableList.of( grabLocksInThread(lockA, lockB, barrier), grabLocksInThread(lockB, lockC, barrier), grabLocksInThread(lockC, lockA, barrier)); // At least one of the threads will report a lock cycle, it is possible that they all will, but // there is no guarantee, so we just scan for the first thread that reported a cycle ListMultimap<Thread, String> cycle = null; for (Future<ListMultimap<Thread, String>> future : futures) { ListMultimap<Thread, String> value = future.get(DEADLOCK_TIMEOUT_SECONDS * 3, TimeUnit.SECONDS); if (!value.isEmpty()) { cycle = value; break; } } // We don't really care about the keys in the multimap, but we want to make sure that all locks // were reported in the right order. assertEquals(6, cycle.size()); Collection<List<String>> edges = Multimaps.asMap(cycle).values(); assertTrue(edges.contains(ImmutableList.of("a", "b"))); assertTrue(edges.contains(ImmutableList.of("b", "c"))); assertTrue(edges.contains(ImmutableList.of("c", "a"))); } private static <T> Future<ListMultimap<Thread, T>> grabLocksInThread( final CycleDetectingLock<T> lock1, final CycleDetectingLock<T> lock2, final CyclicBarrier barrier) { FutureTask<ListMultimap<Thread, T>> future = new FutureTask<ListMultimap<Thread, T>>( new Callable<ListMultimap<Thread, T>>() { @Override public ListMultimap<Thread, T> call() throws Exception { assertTrue(lock1.lockOrDetectPotentialLocksCycle().isEmpty()); barrier.await(); ListMultimap<Thread, T> cycle = lock2.lockOrDetectPotentialLocksCycle(); if (cycle == null) { lock2.unlock(); } lock1.unlock(); return cycle; } }); Thread thread = new Thread(future); thread.start(); return future; } }