/*
* Copyright 2015
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.dkpro.lab.engine.impl;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dkpro.lab.engine.ExecutionException;
import org.dkpro.lab.engine.LifeCycleException;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.engine.TaskExecutionEngine;
import org.dkpro.lab.engine.TaskExecutionService;
import org.dkpro.lab.storage.UnresolvedImportException;
import org.dkpro.lab.task.BatchTask;
import org.dkpro.lab.task.Task;
import org.dkpro.lab.task.TaskContextMetadata;
import org.springframework.beans.factory.annotation.Value;
public class MultiThreadBatchTaskEngine
    extends BatchTaskEngine
{
    private final Log log = LogFactory.getLog(getClass());

    /**
     * Name of the configuration property from which the maximum number of worker threads is
     * injected.
     */
    public static final String PROP_THREADS = "engine.batch.maxThreads";

    /**
     * Maximum number of subtasks executed concurrently. Defaults to one less than the number of
     * available processors (which is 0 on a single-processor machine); values below 1 are
     * treated as 1 when the thread pool is created.
     */
    @Value("#{ @Properties['" + PROP_THREADS + "'] }")
    private int maxThreads = Runtime.getRuntime().availableProcessors() - 1;

    /**
     * Explicit no-args constructor
     */
    public MultiThreadBatchTaskEngine()
    {
        // Nothing to do.
    }

    /**
     * Constructor with number of threads.
     *
     * @param aNThreads
     *            The number of threads to use for the MultiThreadBatchTask.
     */
    public MultiThreadBatchTaskEngine(int aNThreads)
    {
        setMaxThreads(aNThreads);
    }

    /**
     * Sets the maximum number of subtasks executed concurrently.
     *
     * @param aNThreads
     *            the number of worker threads; values below 1 are treated as 1 when the thread
     *            pool is created.
     */
    public void setMaxThreads(int aNThreads)
    {
        maxThreads = aNThreads;
    }

    /**
     * Executes all subtasks of the given batch task for one parameter configuration. Subtasks
     * without a reusable existing execution are run concurrently on a fixed-size thread pool.
     * <p>
     * Subtasks are processed in rounds. In each round every queued subtask is either matched to
     * an existing execution or submitted to the pool, and all submitted subtasks are then
     * awaited. Failed subtasks are re-queued for the next round (presumably to give subtasks
     * whose imports are produced by concurrently running siblings another chance). Since only
     * failed subtasks are re-queued, the set of failures can only shrink from round to round;
     * rounds stop as soon as it no longer changes. If failures remain at that point, they are
     * reported by throwing.
     *
     * @throws ExecutionException
     *             if configuring a subtask fails.
     * @throws LifeCycleException
     *             if configuring a subtask fails.
     * @throws RuntimeException
     *             if subtasks still fail after the last round, if a subtask completes without
     *             producing an execution context, or if the calling thread is interrupted while
     *             waiting for subtasks.
     */
    @Override
    protected void executeConfiguration(BatchTask aConfiguration, TaskContext aContext,
            Map<String, Object> aConfig, Set<String> aExecutedSubtasks)
        throws ExecutionException, LifeCycleException
    {
        if (log.isTraceEnabled()) {
            // Show all subtasks executed so far
            for (String est : aExecutedSubtasks) {
                log.trace("-- Already executed: " + est);
            }
        }

        // Set up initial scope used by sub-batch-tasks using the inherited scope. The scope is
        // extended as the subtasks of this batch are executed with the present configuration.
        // FIXME: That means that sub-batch-tasks in two different configurations cannot see
        // each other. Is that intended? Mind that the "executedSubtasks" set is intentionally
        // maintained *across* configurations, so maybe the scope should also be maintained
        // *across* configurations? - REC 2014-06-15
        Set<String> scope = new HashSet<>();
        if (aConfiguration.getScope() != null) {
            scope.addAll(aConfiguration.getScope());
        }

        // Configure subtasks
        for (Task task : aConfiguration.getTasks()) {
            aContext.getLifeCycleManager().configure(aContext, task, aConfig);
        }

        Queue<Task> queue = new LinkedList<>(aConfiguration.getTasks());
        // Execution threads submitted in the current round; their results are collected once
        // all futures have completed.
        Map<Task, ExecutionThread> threads = new HashMap<>();
        // Futures submitted in the current round and their associated tasks.
        Map<Future<?>, Task> futures = new HashMap<>();
        // Failed tasks of the previous and of the current round.
        ConcurrentMap<Task, Throwable> exceptionsFromLastLoop = null;
        ConcurrentMap<Task, Throwable> exceptionsFromCurrentLoop = new ConcurrentHashMap<>();
        int outerLoopCounter = 0;

        ExecutorService executor = Executors.newFixedThreadPool(getPoolSize());
        try {
            // main loop: one iteration per round
            do {
                outerLoopCounter++;
                threads.clear();
                futures.clear();
                exceptionsFromLastLoop = new ConcurrentHashMap<>(exceptionsFromCurrentLoop);
                // Must be cleared, otherwise the loop would run at most twice.
                exceptionsFromCurrentLoop.clear();

                // Process all tasks from the queue
                while (!queue.isEmpty()) {
                    Task task = queue.poll();
                    // Check if a subtask execution compatible with the present configuration
                    // already exists ...
                    TaskContextMetadata execution = getExistingExecution(aConfiguration,
                            aContext, task, aConfig, aExecutedSubtasks);
                    if (execution == null) {
                        // ... otherwise execute it with the present configuration
                        log.info("Executing task [" + task.getType() + "]");
                        // Hand out snapshots of the shared sets: this thread keeps adding to
                        // them below while the submitted task may read them on a worker
                        // thread, and HashSet is not thread-safe.
                        if (task instanceof BatchTask) {
                            // set scope here so that the inherited scopes are considered
                            ((BatchTask) task).setScope(new HashSet<>(scope));
                        }
                        ExecutionThread thread = new ExecutionThread(aContext, task, aConfig,
                                new HashSet<>(aExecutedSubtasks));
                        threads.put(task, thread);
                        futures.put(executor.submit(thread), task);
                    }
                    else {
                        log.debug("Using existing execution [" + execution.getId() + "]");
                        // Record existing execution
                        aExecutedSubtasks.add(execution.getId());
                        scope.add(execution.getId());
                    }
                }

                // Wait for all submitted tasks and record the failed ones
                awaitAll(futures, exceptionsFromCurrentLoop);
                log.debug("All threads finished");

                // Collect the results
                for (Map.Entry<Task, ExecutionThread> entry : threads.entrySet()) {
                    Task task = entry.getKey();
                    TaskContextMetadata execution = entry.getValue().getTaskContextMetadata();
                    if (execution != null) {
                        // Record new execution
                        aExecutedSubtasks.add(execution.getId());
                        scope.add(execution.getId());
                    }
                    else if (exceptionsFromCurrentLoop.containsKey(task)) {
                        // Failed: re-add to the queue for the next round
                        queue.add(task);
                    }
                    else {
                        // Completed without error, yet no metadata was stored
                        throw new IllegalStateException("Task [" + task.getType()
                                + "] completed without producing an execution context");
                    }
                }
            }
            // Finish if the same tasks failed again
            while (!exceptionsFromCurrentLoop.keySet().equals(exceptionsFromLastLoop.keySet()));
        }
        finally {
            // On the normal path all submitted tasks have already completed. On abort paths
            // this interrupts still-running tasks, and on every path it lets the pool's worker
            // threads terminate instead of leaking them.
            executor.shutdownNow();
        }

        if (!exceptionsFromCurrentLoop.isEmpty()) {
            throw toFailure(exceptionsFromCurrentLoop);
        }

        log.info("MultiThreadBatchTask completed successfully. Total number of outer loop runs: "
                + outerLoopCounter);
    }

    /**
     * Returns the number of worker threads to create.
     * <p>
     * {@link Executors#newFixedThreadPool(int)} rejects values below 1. The default
     * ({@code availableProcessors() - 1}) yields such a value on a single-processor machine, so
     * the configured value is clamped to at least 1.
     */
    private int getPoolSize()
    {
        return Math.max(1, maxThreads);
    }

    /**
     * Blocks until every submitted future has completed and records the tasks whose execution
     * failed.
     *
     * @param aFutures
     *            the futures submitted in the current round, mapped to their tasks.
     * @param aFailures
     *            receives each failed task mapped to the
     *            {@link java.util.concurrent.ExecutionException} reported by its future.
     * @throws RuntimeException
     *             if the calling thread is interrupted while waiting; the interrupt flag is
     *             restored before throwing.
     */
    private void awaitAll(Map<Future<?>, Task> aFutures, Map<Task, Throwable> aFailures)
    {
        for (Map.Entry<Future<?>, Task> entry : aFutures.entrySet()) {
            try {
                entry.getKey().get();
            }
            catch (java.util.concurrent.ExecutionException ex) {
                Task task = entry.getValue();
                log.info("Task exec failed for [" + task.getType() + "]");
                // record the failed task, so that it can be re-added to the queue
                aFailures.put(task, ex);
            }
            catch (InterruptedException ex) {
                // thread interrupted: preserve the interrupt status and exit
                Thread.currentThread().interrupt();
                throw new RuntimeException(ex);
            }
        }
    }

    /**
     * Builds the exception reported for subtasks that still failed after the last round.
     * <p>
     * The first failure (in the map's iteration order) is returned as-is if it is a
     * {@link RuntimeException}. Otherwise it is wrapped in a {@link RuntimeException} whose
     * message lists the messages of all failures.
     *
     * @param aFailures
     *            the failed tasks and their exceptions; must not be empty.
     * @return the exception to throw.
     */
    private static RuntimeException toFailure(Map<Task, Throwable> aFailures)
    {
        // collect all details
        StringBuilder details = new StringBuilder();
        for (Throwable throwable : aFailures.values()) {
            details.append("\n -");
            details.append(throwable.getMessage());
        }

        Throwable first = aFailures.values().iterator().next();
        if (first instanceof RuntimeException) {
            return (RuntimeException) first;
        }
        // otherwise wrap it
        return new RuntimeException(details.toString(), first);
    }

    /**
     * Executes a single subtask with its own engine. The engine's context factory is scoped to
     * the given configuration and to the given set of visible execution IDs.
     * <p>
     * Within this class, instances are submitted to an {@link ExecutorService} as plain
     * runnables rather than started as threads.
     */
    protected class ExecutionThread
        extends Thread
    {
        private final TaskContext context;
        private final Task task;
        private final Map<String, Object> config;
        private final Set<String> scope;
        // Written on the worker thread, read by the submitting thread via the getter.
        private volatile TaskContextMetadata taskContextMetadata;

        /**
         * @param aContext
         *            the context of the enclosing batch task; provides the execution and
         *            storage services.
         * @param aTask
         *            the subtask to run.
         * @param aConfig
         *            the parameter configuration to run the subtask with.
         * @param aScope
         *            the execution IDs handed to the scoped context factory.
         */
        public ExecutionThread(TaskContext aContext, Task aTask, Map<String, Object> aConfig,
                Set<String> aScope)
        {
            this.context = aContext;
            this.task = aTask;
            this.config = aConfig;
            this.scope = aScope;
        }

        /**
         * Runs the subtask and stores the metadata of the resulting execution context.
         *
         * @throws RuntimeException
         *             wrapping any {@link ExecutionException} or {@link LifeCycleException}
         *             raised by the engine.
         */
        @Override
        public void run()
        {
            TaskExecutionService execService = context.getExecutionService();
            TaskExecutionEngine engine = execService.createEngine(task);
            engine.setContextFactory(new ScopedTaskContextFactory(
                    execService.getContextFactory(), config, scope));
            String uuid;
            try {
                uuid = engine.run(task);
            }
            catch (ExecutionException | LifeCycleException e) {
                throw new RuntimeException(e);
            }
            taskContextMetadata = context.getStorageService().getContext(uuid);
        }

        /**
         * Returns the result of the run.
         *
         * @return the metadata of the execution context, or {@code null} if {@link #run()} has
         *         not completed successfully or the storage service returned no context.
         */
        public TaskContextMetadata getTaskContextMetadata()
        {
            return taskContextMetadata;
        }
    }
}