package com.acuitra.pipeline;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Runs a set of {@link RunnablePipeline}s concurrently, one pool thread per registered pipeline,
 * and collects each pipeline's outputs after they finish or a timeout elapses.
 *
 * <p>Typical use: register pipelines with {@link #addPipeline}, call {@link #run()}, then read
 * the results from {@link #getOutputs()}.
 *
 * <p>The {@code running} flag is a plain (non-volatile) field. It rejects calls to
 * {@link #addPipeline} and {@link #run()} made while a run is in progress. It is not a
 * synchronization mechanism.
 *
 * @param <T> type parameter forwarded to {@link RunnablePipeline}
 * @param <O> type of the output values each pipeline exposes
 */
public class ParallelPipelineRunner<T, O> {
    private boolean running = false;
    Logger logger = LoggerFactory.getLogger(this.getClass());
    private final List<RunnablePipeline<T, O>> pipelines = new ArrayList<>();
    /** Outputs of the most recent run(s), keyed by pipeline name. */
    private final Map<String, Map<String, O>> outputs = new HashMap<>();
    private final int timeoutMillis;

    /**
     * Creates a runner with no pipelines registered.
     *
     * @param timeoutMillis maximum time, in milliseconds, that {@link #run()} waits for all
     *     pipelines to finish
     */
    public ParallelPipelineRunner(int timeoutMillis) {
        this.timeoutMillis = timeoutMillis;
    }

    /**
     * Registers a pipeline to be executed by subsequent calls to {@link #run()}.
     *
     * @param pipeline the pipeline to add
     * @throws IllegalStateException if {@link #run()} is currently in progress
     */
    public void addPipeline(RunnablePipeline<T, O> pipeline) {
        if (running) {
            throw new IllegalStateException("Cannot add when already running");
        }
        pipelines.add(pipeline);
    }

    /**
     * Executes every registered pipeline on its own pool thread and waits up to
     * {@link #getTimeoutMillis()} milliseconds for all of them to finish.
     *
     * <p>After waiting, each pipeline's {@code getContext().getPreviousOutputs()} is stored in
     * {@link #getOutputs()} under the pipeline's {@code getName()}. This overwrites any entry with
     * the same name from an earlier run.
     *
     * <ul>
     *   <li>On timeout, an error is logged and the pool is told to stop via
     *       {@link ExecutorService#shutdownNow()}, which interrupts its worker threads. Outputs are
     *       still collected as-is and may be incomplete for pipelines that had not finished.
     *   <li>If the calling thread is interrupted while waiting, the pool is likewise stopped, no
     *       outputs are collected, and the calling thread's interrupt status is restored.
     *   <li>With no registered pipelines this is a no-op.
     * </ul>
     *
     * @throws IllegalStateException if a run is already in progress
     */
    public void run() {
        if (running) {
            throw new IllegalStateException("Already running");
        }
        running = true;
        try {
            if (pipelines.isEmpty()) {
                // newFixedThreadPool rejects a pool size of 0, and there is nothing to run anyway.
                return;
            }
            ExecutorService executor = Executors.newFixedThreadPool(pipelines.size());
            try {
                for (RunnablePipeline<T, O> pipe : pipelines) {
                    executor.execute(pipe);
                }
            } finally {
                // Always stop accepting work, even if a submission failed, so pool threads
                // can exit instead of leaking.
                executor.shutdown();
            }
            try {
                // Wait until all pipelines have finished (or the timeout elapses).
                if (!executor.awaitTermination(getTimeoutMillis(), TimeUnit.MILLISECONDS)) {
                    logger.error("Pipeline execution took too long");
                    // Without this, timed-out pipelines would keep running on pool threads.
                    executor.shutdownNow();
                }
                for (RunnablePipeline<T, O> pipe : pipelines) {
                    outputs.put(pipe.getName(), pipe.getContext().getPreviousOutputs());
                }
            } catch (InterruptedException e1) {
                executor.shutdownNow();
                // Restore the interrupt so callers further up the stack can observe it.
                Thread.currentThread().interrupt();
                logger.error("Error", e1);
            }
        } finally {
            running = false;
        }
    }

    /**
     * Returns the collected pipeline outputs, keyed by pipeline name.
     *
     * <p>This is the runner's internal map, not a copy. It is empty until {@link #run()} has
     * collected outputs.
     *
     * @return map from pipeline name to that pipeline's outputs
     */
    public Map<String, Map<String, O>> getOutputs() {
        return outputs;
    }

    /**
     * Returns the maximum time {@link #run()} waits for the pipelines.
     *
     * @return the timeout in milliseconds, as passed to the constructor
     */
    public int getTimeoutMillis() {
        return timeoutMillis;
    }
}