/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.executor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import ml.shifu.shifu.util.Environment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Runs batches of tasks on a fixed-size thread pool and blocks the caller until they complete.
 *
 * <p>Created by zhanhu on 12/12/16.
 *
 * @param <T> result type of the {@link Callable} tasks passed to {@link #submitTasksAndWaitResults(List)}
 */
public class ExecutorManager<T> {

    private static final Logger LOG = LoggerFactory.getLogger(ExecutorManager.class);

    /** Pool that executes every task submitted through this manager. */
    private final ExecutorService executorService;

    /**
     * Creates a manager whose pool size comes from the {@code shifu.combo.thread.parallel} setting
     * (10 is passed as the default value to {@code Environment.getInt}).
     */
    public ExecutorManager() {
        this(Environment.getInt("shifu.combo.thread.parallel", 10));
    }

    /**
     * Creates a manager backed by a fixed thread pool.
     *
     * @param threadPoolSize number of worker threads
     * @throws IllegalArgumentException if {@code threadPoolSize <= 0} (thrown by {@link Executors#newFixedThreadPool})
     */
    public ExecutorManager(int threadPoolSize) {
        this.executorService = Executors.newFixedThreadPool(threadPoolSize);
    }

    /**
     * Submits all tasks, then blocks until each one has finished.
     *
     * <p>A task that throws is logged and does not stop the wait for the remaining tasks. If the calling thread
     * is interrupted, waiting stops early and the thread's interrupt status is restored; tasks that were already
     * submitted are not cancelled.
     *
     * @param tasks tasks to run
     */
    public void submitTasksAndWaitFinish(List<Runnable> tasks) {
        List<Future<?>> futureList = new ArrayList<Future<?>>(tasks.size());
        for (Runnable task : tasks) {
            futureList.add(executorService.submit(task));
        }
        for (Future<?> future : futureList) {
            try {
                future.get();
            } catch (InterruptedException e) {
                LOG.error("Error occurred, when waiting task to finish.", e);
                // Restore the flag so the caller can observe it, and stop blocking on the remaining futures.
                Thread.currentThread().interrupt();
                return;
            } catch (ExecutionException e) {
                LOG.error("Error occurred, when waiting task to finish.", e);
            }
        }
    }

    /**
     * Submits all tasks, then blocks until each one has finished and collects the results.
     *
     * <p>Results are returned in submission order, but a task that throws is logged and <em>omitted</em>, so the
     * returned list may be shorter than {@code tasks} and indices are not guaranteed to line up. If the calling
     * thread is interrupted, waiting stops early, the interrupt status is restored and the results gathered so far
     * are returned.
     *
     * @param tasks tasks to run
     * @return results of the tasks that completed normally
     */
    public List<T> submitTasksAndWaitResults(List<Callable<T>> tasks) {
        List<T> results = new ArrayList<T>();
        List<Future<T>> futureList = new ArrayList<Future<T>>(tasks.size());
        for (Callable<T> task : tasks) {
            futureList.add(executorService.submit(task));
        }
        for (Future<T> future : futureList) {
            try {
                results.add(future.get());
            } catch (InterruptedException e) {
                LOG.error("Error occurred, when waiting task to finish.", e);
                // Restore the flag so the caller can observe it, and stop blocking on the remaining futures.
                Thread.currentThread().interrupt();
                break;
            } catch (ExecutionException e) {
                LOG.error("Error occurred, when waiting task to finish.", e);
            }
        }
        return results;
    }

    /**
     * Submits all tasks and re-submits any task that fails, until it succeeds or its attempts are used up.
     *
     * <p>A task attempt counts as successful only when it returns {@code 0}. Returning {@code null}, a non-zero
     * value, or throwing counts as a failure. Each task gets at most {@code max(1, maxRetryTimes)} attempts in total,
     * including the first one.
     *
     * <p>If the calling thread is interrupted, no further retries are made, the interrupt status is restored and
     * waiting stops. Entries not yet filled in stay {@code null}.
     *
     * @param tasks tasks to run; each is identified by its index in this list
     * @param maxRetryTimes total attempts allowed per task (including the first)
     * @return one entry per task, aligned with {@code tasks}: {@code 0} on success, the last returned value if
     *         every attempt returned {@code null} or a non-zero value, {@code 1} if the last attempt threw, or
     *         {@code null} if waiting was interrupted first
     */
    public List<Integer> submitTasksAndRetryIfFail(List<Callable<Integer>> tasks, int maxRetryTimes) {
        int taskCount = tasks.size();
        List<Integer> results = new ArrayList<Integer>(taskCount);
        int[] taskLeftTryTimes = new int[taskCount];
        Arrays.fill(taskLeftTryTimes, maxRetryTimes);

        List<TaskFuture> taskFutures = new ArrayList<TaskFuture>(taskCount);
        for (int i = 0; i < taskCount; i++) {
            taskFutures.add(new TaskFuture(i, executorService.submit(tasks.get(i))));
            results.add(null);
        }

        // retryTask appends new futures to taskFutures, so the bound is re-read on every iteration.
        for (int i = 0; i < taskFutures.size(); i++) {
            TaskFuture tf = taskFutures.get(i);
            try {
                Integer res = tf.getFuture().get();
                boolean succeeded = (res != null && res == 0);
                // Record the result on success, or on failure once no retry is possible.
                if (succeeded || !retryTask(tf, tasks, taskFutures, taskLeftTryTimes, maxRetryTimes)) {
                    results.set(tf.getTaskId(), res);
                }
            } catch (InterruptedException e) {
                // Don't retry: the pool may be shutting down. Restore the flag and stop waiting.
                Thread.currentThread().interrupt();
                break;
            } catch (ExecutionException e) {
                if (!retryTask(tf, tasks, taskFutures, taskLeftTryTimes, maxRetryTimes)) {
                    results.set(tf.getTaskId(), 1);
                }
            }
        }
        return results;
    }

    /**
     * Consumes one attempt of the failed task and re-submits it if attempts remain.
     *
     * @param tf the failed task/future pair
     * @param tasks original task list, indexed by task id
     * @param taskFutures list the new future is appended to when a retry is submitted
     * @param taskLeftTryTimes remaining attempts per task id; decremented here
     * @param maxRetryTimes configured attempt budget, used only for the log message
     * @return {@code true} if the task was re-submitted, {@code false} if its attempts are exhausted
     */
    private boolean retryTask(TaskFuture tf, List<Callable<Integer>> tasks, List<TaskFuture> taskFutures,
            int[] taskLeftTryTimes, int maxRetryTimes) {
        int taskId = tf.getTaskId();
        taskLeftTryTimes[taskId]--;
        if (taskLeftTryTimes[taskId] <= 0) {
            return false;
        }
        Future<Integer> future = executorService.submit(tasks.get(taskId));
        taskFutures.add(new TaskFuture(taskId, future));
        LOG.warn("Retry task - {} with {}-th times.", taskId, (maxRetryTimes - taskLeftTryTimes[taskId]));
        return true;
    }

    /**
     * Stops accepting new tasks and blocks until all submitted tasks finish. If the calling thread is interrupted
     * while waiting, the interrupt status is restored and the method returns.
     */
    public void graceShutDown() {
        this.executorService.shutdown();
        try {
            this.executorService.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            LOG.error("Error occurred, when waiting task to finish.", e);
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Interrupts running tasks, drops queued ones, and waits up to 2 seconds for the pool to terminate.
     */
    public void forceShutDown() {
        this.executorService.shutdownNow();
        try {
            this.executorService.awaitTermination(2, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Pairs a submitted attempt's {@link Future} with the index of the task it belongs to.
     */
    public static class TaskFuture {

        private final int taskId;
        private final Future<Integer> future;

        public TaskFuture(int taskId, Future<Integer> future) {
            this.taskId = taskId;
            this.future = future;
        }

        /** @return index of the task in the originally submitted list */
        public int getTaskId() {
            return taskId;
        }

        /** @return future of this particular attempt */
        public Future<Integer> getFuture() {
            return future;
        }
    }
}