/* * Copyright 2017 ThoughtWorks, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.thoughtworks.go.server.dao; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; public class ThreadSafetyChecker { private long testTimeoutTime; private List<Operation> operations; private ConcurrentMap<Thread, Throwable> exceptions; public ThreadSafetyChecker(long timeoutInMillisecondsForEveryThreadJoin) { this.testTimeoutTime = timeoutInMillisecondsForEveryThreadJoin; this.operations = new ArrayList<>(); this.exceptions = new ConcurrentHashMap<>(); } public void addOperation(Operation operation) { operations.add(operation); } public void run(final int numberOfTimesToRunTheStuffInsideTheOperations) throws Exception { Thread.setDefaultUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() { @Override public void uncaughtException(Thread thread, Throwable throwable) { exceptions.put(thread, throwable); } }); List<Thread> threads = createThreads(numberOfTimesToRunTheStuffInsideTheOperations); startThreads(threads); waitForThreadsToFinish(threads); assertThat(exceptions.toString(), exceptions.size(), is(0)); } private void waitForThreadsToFinish(List<Thread> threads) throws Exception { for (Thread thread : threads) { thread.join(testTimeoutTime); } } private List<Thread> createThreads(final int numberOfTimesToRunTheStuffInsideTheOperations) { List<Thread> threads = new ArrayList<>(); for (int i = 0; i < operations.size(); i++) { final Operation operation = operations.get(i); Thread thread = new Thread(new Runnable() { @Override public void run() { for (int runIndex = 0; runIndex < numberOfTimesToRunTheStuffInsideTheOperations; runIndex++) { operation.execute(runIndex); } } }, "ThreadSafetyChecker-" + i); threads.add(thread); } return threads; } private void startThreads(List<Thread> threads) { for (Thread thread : threads) { thread.start(); } } public static abstract class Operation { public abstract void execute(int runIndex); } }