/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.teradata.tempto.threads; import com.google.common.collect.ImmutableList; import java.util.List; import static com.google.common.collect.Lists.newArrayList; import static java.lang.System.currentTimeMillis; import static java.util.Collections.synchronizedList; import static java.util.stream.Collectors.toList; /** * A class implementing parallel execution of code blocks. */ public class ParallelExecution { private final List<Thread> threads; private final List<Throwable> throwables = synchronizedList(newArrayList()); private ParallelExecution(List<IndexedRunnable> runnables) { threads = asThreads(runnables); } public ParallelExecution start() { threads.stream().forEach(Thread::start); return this; } /** * Joins all child threads and throws {@link ParallelExecutionException} if some * child throws a {@link Throwable}. * * @throws InterruptedException if the thread is interrupted */ public void joinAndRethrow() throws InterruptedException { joinAndRethrow(0L); } /** * @param timeout Milliseconds * @throws InterruptedException if the thread is interrupted * @return true if child threads were successfully joined within given timeout. */ public boolean joinAndRethrow(long timeout) throws InterruptedException { boolean joinedWithinTimeout = join(timeout); if (!throwables.isEmpty()) { throw new ParallelExecutionException(throwables); } return joinedWithinTimeout; } public void join() throws InterruptedException { join(0L); } public boolean join(long timeout) throws InterruptedException { if (timeout == 0) { timeout = Long.MAX_VALUE; } long startTime = currentTimeMillis(); for (Thread thread : threads) { long remaining = calculateRemainingTime(startTime, timeout); if (remaining > 0) { thread.join(remaining); } } return calculateRemainingTime(startTime, timeout) > 0; } private long calculateRemainingTime(long startTime, long timeout) { long elapsed = currentTimeMillis() - startTime; return timeout - elapsed; } /** * @return {@link Throwable}s that were caught in child threads during execution. */ public List<Throwable> getThrowables() { return throwables; } private List<Thread> asThreads(List<IndexedRunnable> runnables) { List<Thread> threads = newArrayList(); for (int i = 0; i < runnables.size(); ++i) { final int threadIndex = i; threads.add(new Thread(() -> { try { runnables.get(threadIndex).run(threadIndex); } catch (Throwable throwable) { throwables.add(throwable); } })); } return threads; } public static ParallelExecution parallelExecution(int nTimes, IndexedRunnable indexedRunnable) { return builder().addRunnable(nTimes, indexedRunnable).build(); } public static ParallelExecutionBuilder builder() { return new ParallelExecutionBuilder(); } public static class ParallelExecutionBuilder { private final List<IndexedRunnable> indexedRunnables = newArrayList(); private final List<Runnable> runnables = newArrayList(); public ParallelExecutionBuilder addRunnable(IndexedRunnable indexedRunnable) { return addRunnable(1, indexedRunnable); } public ParallelExecutionBuilder addRunnable(int nTimes, IndexedRunnable indexedRunnable) { for (int i = 0; i < nTimes; ++i) { indexedRunnables.add(indexedRunnable); } return this; } public ParallelExecutionBuilder addRunnable(Runnable runnable) { runnables.add(runnable); return this; } public ParallelExecution build() { List<IndexedRunnable> allIndexedRunnables = ImmutableList.<IndexedRunnable>builder() .addAll(indexedRunnables) .addAll(asParallelRunnables(runnables)) .build(); return new ParallelExecution(allIndexedRunnables); } } private static List<IndexedRunnable> asParallelRunnables(List<Runnable> runnables) { return runnables .stream() .map((Runnable runnable) -> (IndexedRunnable) (int threadIndex) -> runnable.run()) .collect(toList()); } }