/**
 * Licensed to Apereo under one or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information regarding copyright ownership. Apereo
 * licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use
 * this file except in compliance with the License. You may obtain a copy of the License at the
 * following location:
 *
 * <p>http://www.apache.org/licenses/LICENSE-2.0
 *
 * <p>Unless required by applicable law or agreed to in writing, software distributed under the
 * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apereo.portal.test;

import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import junit.framework.AssertionFailedError;
import org.apereo.portal.utils.ConcurrentMapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utility for running several threads in a test.
 *
 * <p>Usage follows three phases: register work with {@link #addTask(Runnable)}, launch it with
 * {@link #start()}, then wait for it with {@link #join()}. Any exception that escapes a task is
 * recorded and re-reported from {@link #join()} as an {@link AssertionFailedError}. While the
 * group is running, {@link #tick(int)} can be used as a barrier between the threads.
 */
public final class ThreadGroupRunner {
    /** Lifecycle of the runner; it only ever moves forward through these states. */
    private enum State {
        SETUP,
        RUNNING,
        COMPLETE;
    }

    /**
     * Thrown out of {@code tick} when the barrier can no longer be reached because another thread
     * of the group already died. It is a consequence of a failure, not a failure of its own, so
     * the uncaught exception handler does not record it.
     */
    private static final class TickAbortedException extends IllegalStateException {
        private static final long serialVersionUID = 1L;

        TickAbortedException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /** How often a thread waiting in {@code tick} re-checks whether the group has failed. */
    private static final long FAILURE_POLL_MILLIS = 100L;

    private final Logger logger = LoggerFactory.getLogger(getClass());
    private final List<Thread> threads = new LinkedList<Thread>();

    /** Root-cause failures keyed by the name of the thread that died. */
    private final Map<String, Throwable> uncaughtExceptions =
            new ConcurrentHashMap<String, Throwable>();

    /** One latch per tick index, created lazily by the first thread to reach that tick. */
    private final ConcurrentMap<Integer, CountDownLatch> latchMap =
            new ConcurrentHashMap<Integer, CountDownLatch>();

    /** Single handler shared by every thread of the group. */
    private final Thread.UncaughtExceptionHandler failureRecorder =
            new Thread.UncaughtExceptionHandler() {
                @Override
                public void uncaughtException(Thread t, Throwable e) {
                    if (e instanceof TickAbortedException) {
                        // Secondary effect of a failure that is already recorded; keeping it
                        // out of the map ensures join() reports the real cause.
                        logger.debug("Thread {} aborted its tick", t.getName(), e);
                        return;
                    }
                    logger.debug("Uncaught Exception", e);
                    uncaughtExceptions.put(t.getName(), e);
                }
            };

    private final String namePrefix;
    private final boolean daemon;
    private volatile State running = State.SETUP;

    /**
     * @param namePrefix prefix of every thread name; the zero-based registration index is appended
     * @param daemon whether the created threads are daemon threads
     */
    public ThreadGroupRunner(String namePrefix, boolean daemon) {
        this.namePrefix = namePrefix;
        this.daemon = daemon;
    }

    /**
     * Add a Runnable that will be executed in its own thread.
     *
     * @param r the work to run
     * @throws IllegalStateException if {@link #start()} has already been called
     */
    public synchronized void addTask(Runnable r) {
        if (running != State.SETUP) {
            throw new IllegalStateException("Can't be called after start() has been called");
        }

        final Thread t = new Thread(r, this.namePrefix + threads.size());
        t.setDaemon(this.daemon);
        t.setUncaughtExceptionHandler(failureRecorder);
        threads.add(t);
    }

    /**
     * Add the same Runnable several times, each registration getting its own thread.
     *
     * @param threadCount number of threads that will run {@code r}
     * @param r the work to run
     * @throws IllegalStateException if {@link #start()} has already been called
     */
    public synchronized void addTask(int threadCount, Runnable r) {
        if (running != State.SETUP) {
            throw new IllegalStateException("Can't be called after start() has been called");
        }

        for (int index = 0; index < threadCount; index++) {
            this.addTask(r);
        }
    }

    /**
     * Start all threads (start is in order of added runnable)
     *
     * @throws IllegalStateException if called more than once
     */
    public synchronized void start() {
        if (running != State.SETUP) {
            throw new IllegalStateException("Can only be called once");
        }

        // Flip the state first so tasks may call tick() as soon as they run.
        running = State.RUNNING;
        for (final Thread t : this.threads) {
            t.start();
        }
    }

    /**
     * Join on all threads (join is in order of added runnable)
     *
     * <p>Every recorded failure is logged; one of them is rethrown wrapped in an {@link
     * AssertionFailedError}. If the calling thread is interrupted while waiting, the runner stays
     * in the running state.
     *
     * @throws IllegalStateException if the runner is not currently running
     * @throws InterruptedException if interrupted while waiting for a thread to finish
     * @throws AssertionFailedError if any thread terminated with an uncaught exception
     */
    public synchronized void join() throws InterruptedException {
        if (running != State.RUNNING) {
            throw new IllegalStateException("Can only be called after start()");
        }

        for (final Thread t : this.threads) {
            t.join();
        }
        running = State.COMPLETE;

        Map.Entry<String, Throwable> exception = null;
        for (final Map.Entry<String, Throwable> exceptionEntry : uncaughtExceptions.entrySet()) {
            if (exception == null) {
                exception = exceptionEntry;
            }
            logger.error(
                    "Thread " + exceptionEntry.getKey() + " failed with an exception",
                    exceptionEntry.getValue());
        }

        if (exception != null) {
            final AssertionFailedError assertionError =
                    new AssertionFailedError(
                            "Thread " + exception.getKey() + " failed with an exception");
            assertionError.initCause(exception.getValue());
            throw assertionError;
        }
    }

    /**
     * Effectively a count down latch where all threads in the group must reach the specified tick
     * before any are allowed to proceed
     *
     * @param index identifier of the barrier
     * @throws IllegalStateException if the runner is not running, or if a thread of the group has
     *     already failed so the barrier can never be reached
     * @throws InterruptedException if interrupted while waiting
     */
    public void tick(int index) throws InterruptedException {
        tick(index, false);
    }

    /**
     * Effectively a count down latch where all threads in the group must reach the specified tick
     * before any are allowed to proceed.
     *
     * <p>The size of the barrier is fixed by whichever caller reaches a given {@code index} first,
     * so all callers of one index must pass the same {@code includeMainThread} value.
     *
     * @param index identifier of the barrier
     * @param includeMainThread If true all threads in the group AND the main thread must call tick
     * @throws IllegalStateException if the runner is not running, or if a thread of the group has
     *     already failed so the barrier can never be reached
     * @throws InterruptedException if interrupted while waiting
     */
    public void tick(int index, boolean includeMainThread) throws InterruptedException {
        if (running != State.RUNNING) {
            throw new IllegalStateException(
                    "Can only be called after start() and before join() returns");
        }

        CountDownLatch latch = latchMap.get(index);
        if (latch == null) {
            final int latchCount = this.threads.size() + (includeMainThread ? 1 : 0);
            final CountDownLatch newLatch = new CountDownLatch(latchCount);
            latch = ConcurrentMapUtils.putIfAbsent(latchMap, index, newLatch);
            if (newLatch == latch) {
                logger.debug("created tick({}) = {}", index, latchCount);
            }
        }

        latch.countDown();
        logger.debug("tick({}) = {}", index, latch.getCount());
        awaitTick(index, latch);
    }

    /**
     * Waits for the latch to open, giving up as soon as a thread of the group is known to have
     * died: that thread will never count down, so waiting any longer would hang the test.
     */
    private void awaitTick(int index, CountDownLatch latch) throws InterruptedException {
        // await() returns true the moment the latch opens, so the normal path is not delayed.
        while (!latch.await(FAILURE_POLL_MILLIS, TimeUnit.MILLISECONDS)) {
            final Map.Entry<String, Throwable> failure = anyFailure();
            if (failure != null) {
                throw new TickAbortedException(
                        "tick("
                                + index
                                + ") can never complete, thread "
                                + failure.getKey()
                                + " failed with an exception",
                        failure.getValue());
            }
        }
    }

    /** Returns one recorded failure, or null when no thread has failed so far. */
    private Map.Entry<String, Throwable> anyFailure() {
        for (final Map.Entry<String, Throwable> failure : uncaughtExceptions.entrySet()) {
            return failure;
        }
        return null;
    }
}