/*
 * Hibernate Search, full-text search for your domain model
 *
 * License: GNU Lesser General Public License (LGPL), version 2.1 or later
 * See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
 */
package org.hibernate.search.testsupport.concurrency;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Helper to create tests which need to execute multiple tasks at "approximately same time",
 * like stress tests.
 * <p>
 * Every task is submitted to a fixed-size thread pool and then waits on a shared start signal,
 * which is released only once all tasks have been submitted. When there are more tasks than
 * threads, at most {@code threads} tasks run at the same time and the remaining ones wait in
 * the executor's queue.
 * <p>
 * If any task fails (with any {@link Throwable}), or the optional timeout expires,
 * {@link #execute()} throws an {@link AssertionError}, causing a test failure.
 * <p>
 * Instances are single-use: {@link #execute()} shuts down the underlying executor.
 *
 * @author Sanne Grinovero (C) 2014 Red Hat Inc.
 */
public class ConcurrentRunner {

	/** Number of tasks created by {@link #ConcurrentRunner(TaskFactory)}. */
	public static final int DEFAULT_REPEAT = 300;

	/** Size of the thread pool used by {@link #ConcurrentRunner(TaskFactory)}. */
	public static final int DEFAULT_THREADS = 30;

	/** Failures recorded by the worker threads, keyed by task index. */
	private final ConcurrentMap<Integer, Throwable> failures = new ConcurrentHashMap<>( 0 );

	private final ExecutorService executor;

	/** Released once all tasks have been submitted, so they start as simultaneously as possible. */
	private final CountDownLatch startLatch = new CountDownLatch( 1 );

	/** Counted down once per task, whatever its outcome. */
	private final CountDownLatch endLatch;

	private final TaskFactory factory;

	private final int repetitions;

	/** {@code null} means "wait indefinitely". */
	private Long timeoutValue;

	private TimeUnit timeoutUnit;

	/**
	 * Provide a source for {@link Runnable} instances to run concurrently, using
	 * {@link #DEFAULT_REPEAT} tasks on {@link #DEFAULT_THREADS} threads.
	 * This is meant to simplify collection and creation of such tasks.
	 * @param factory the source of Runnable instances to run
	 */
	public ConcurrentRunner(TaskFactory factory) {
		this( DEFAULT_REPEAT, DEFAULT_THREADS, factory );
	}

	/**
	 * Provide a source for {@link Runnable} instances to run concurrently.
	 * This is meant to simplify collection and creation of such tasks.
	 * @param repetitions the amount of times the task should be repeated
	 * @param threads the number of threads used to run the task in parallel
	 * @param factory the source of Runnable instances to run.
	 * @throws IllegalArgumentException if {@code threads <= 0} or {@code repetitions < 0}
	 * @throws NullPointerException if {@code factory} is null
	 */
	public ConcurrentRunner(int repetitions, int threads, TaskFactory factory) {
		this.repetitions = repetitions;
		this.factory = Objects.requireNonNull( factory, "factory" );
		executor = Executors.newFixedThreadPool( threads );
		endLatch = new CountDownLatch( repetitions );
	}

	/**
	 * Sets the maximum time {@link #execute()} waits for all tasks to complete.
	 * When it expires, the executor is shut down with {@link ExecutorService#shutdownNow()}
	 * and an {@link AssertionError} is thrown.
	 * @param timeoutValue the maximum time to wait
	 * @param timeoutUnit the unit of {@code timeoutValue}
	 * @return this runner, for method chaining
	 * @throws NullPointerException if {@code timeoutUnit} is null
	 */
	public ConcurrentRunner setTimeout(long timeoutValue, TimeUnit timeoutUnit) {
		this.timeoutUnit = Objects.requireNonNull( timeoutUnit, "timeoutUnit" );
		this.timeoutValue = timeoutValue;
		return this;
	}

	/**
	 * Invokes the {@link TaskFactory} and runs all the built tasks in
	 * an Executor.
	 * @throws Exception if any exception is thrown during the creation of tasks; in that case
	 * the executor is shut down so that already-submitted tasks don't leak threads.
	 * @throws AssertionError if interrupted, if the timeout expires, or if any task throws.
	 * The first reported task failure is the cause; further ones are added as suppressed.
	 */
	public void execute() throws Exception, AssertionError {
		boolean allSubmitted = false;
		try {
			for ( int i = 0; i < repetitions; i++ ) {
				Runnable userRunnable = factory.createRunnable( i );
				executor.execute( new WrapRunnable( startLatch, endLatch, i, userRunnable ) );
			}
			allSubmitted = true;
		}
		finally {
			if ( !allSubmitted ) {
				// Already-submitted tasks are parked on startLatch, which will never be released:
				// interrupt them so the pool threads terminate instead of leaking.
				executor.shutdownNow();
			}
		}
		executor.shutdown();
		startLatch.countDown();

		boolean timedOut = false;
		try {
			if ( timeoutValue != null ) {
				if ( !endLatch.await( timeoutValue, timeoutUnit ) ) {
					executor.shutdownNow();
					timedOut = true;
				}
			}
			else {
				endLatch.await();
			}
		}
		catch (InterruptedException e) {
			executor.shutdownNow();
			Thread.currentThread().interrupt();
			throw new AssertionError( "Interrupted while awaiting for end of execution", e );
		}

		AssertionError reportedError = null;
		if ( timedOut ) {
			reportedError = new AssertionError( "The thread pool didn't finish executing after "
					+ timeoutValue + " " + timeoutUnit );
			// Go on and also add errors (if any) as suppressed exceptions
		}
		for ( Map.Entry<Integer, Throwable> entry : failures.entrySet() ) {
			if ( reportedError == null ) {
				reportedError = new AssertionError( "Unexpected failure on task #" + entry.getKey(), entry.getValue() );
			}
			else {
				reportedError.addSuppressed( entry.getValue() );
			}
		}
		if ( reportedError != null ) {
			throw reportedError;
		}
	}

	/**
	 * Wraps a user task: waits for the start signal, runs the task unless a failure
	 * was already recorded, records any {@link Throwable}, and always counts down the end latch.
	 */
	private class WrapRunnable implements Runnable {

		private final CountDownLatch startLatch;
		private final CountDownLatch endLatch;
		private final Integer taskIndex;
		private final Runnable userRunnable;

		public WrapRunnable(CountDownLatch startLatch, CountDownLatch endLatch, Integer taskIndex, Runnable userRunnable) {
			this.startLatch = startLatch;
			this.endLatch = endLatch;
			this.taskIndex = taskIndex;
			this.userRunnable = userRunnable;
		}

		@Override
		public void run() {
			try {
				// Wait until every task has been submitted, so they start as close together as possible
				startLatch.await();
				// Prevent more work from being done if something failed already
				if ( failures.isEmpty() ) {
					userRunnable.run();
				}
			}
			catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				failures.put( taskIndex, e );
			}
			catch (Throwable t) {
				// Record every failure, including Errors other than AssertionError,
				// so that it gets reported instead of silently killing the task
				failures.put( taskIndex, t );
			}
			finally {
				// Must always happen, otherwise execute() would wait forever (or until the timeout)
				endLatch.countDown();
			}
		}
	}

	/**
	 * Creates the tasks to run concurrently.
	 */
	public interface TaskFactory {
		/**
		 * @param i the index of the task, from {@code 0} (inclusive) to the number of repetitions (exclusive)
		 * @return the task to run
		 * @throws Exception if the task cannot be created; it is propagated by {@link ConcurrentRunner#execute()}
		 */
		Runnable createRunnable(int i) throws Exception;
	}
}