package com.tngtech.archunit.junit; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import com.tngtech.archunit.Slow; import com.tngtech.archunit.core.importer.ImportOptions; import com.tngtech.archunit.junit.ClassCache.CacheClassFileImporter; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; import org.mockito.InjectMocks; import org.mockito.Spy; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; import static java.util.concurrent.TimeUnit.MINUTES; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyCollection; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @Category(Slow.class) public class ClassCacheConcurrencyTest { private static final int NUM_THREADS = 20; private static final List<Class<?>> TEST_CLASSES = Arrays.asList( TestClass1.class, TestClass2.class, TestClass3.class, TestClass4.class, TestClass5.class, TestClass6.class ); @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @Spy private CacheClassFileImporter classFileImporter; @InjectMocks private ClassCache cache = new ClassCache(); private final ExecutorService executorService = Executors.newFixedThreadPool(NUM_THREADS); @Test @SuppressWarnings("unchecked") public void concurrent_access() throws Exception { List<Future<?>> futures = new ArrayList<>(); for (int i = 0; i < NUM_THREADS; i++) { futures.add(executorService.submit(repeatGetClassesToAnalyze(1000))); } for (Future<?> future : futures) { future.get(1, MINUTES); } verify(classFileImporter, atMost(TEST_CLASSES.size())).importClasses(any(ImportOptions.class), anyCollection()); verifyNoMoreInteractions(classFileImporter); } private Runnable repeatGetClassesToAnalyze(final int times) { return new Runnable() { @Override public void run() { for (int j = 0; j < times; j++) { cache.getClassesToAnalyzeFor(TEST_CLASSES.get(j % TEST_CLASSES.size())); } } }; } @AnalyzeClasses(packages = "com.tngtech.archunit.junit") public static class TestClass1 { } @AnalyzeClasses(packages = "com.tngtech.archunit.example") public static class TestClass2 { } @AnalyzeClasses(packages = "com.tngtech.archunit.integration") public static class TestClass3 { } @AnalyzeClasses(packages = "com.tngtech.archunit.core") public static class TestClass4 { } @AnalyzeClasses(packages = "com.tngtech.archunit") public static class TestClass5 { } @AnalyzeClasses(packages = "com.tngtech") public static class TestClass6 { } }