package org.embulk.exec;
import java.util.List;
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ExecutionException;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import com.google.inject.Inject;
import org.embulk.config.ConfigSource;
import org.embulk.config.TaskSource;
import org.embulk.config.TaskReport;
import org.embulk.spi.Exec;
import org.embulk.spi.ExecSession;
import org.embulk.spi.ExecutorPlugin;
import org.embulk.spi.ProcessTask;
import org.embulk.spi.ProcessState;
import org.embulk.spi.Schema;
import org.embulk.spi.InputPlugin;
import org.embulk.spi.FilterPlugin;
import org.embulk.spi.OutputPlugin;
import org.embulk.spi.Page;
import org.embulk.spi.PageOutput;
import org.embulk.spi.AbortTransactionResource;
import org.embulk.spi.CloseResource;
import org.embulk.spi.TransactionalPageOutput;
import org.embulk.plugin.compat.PluginWrappers;
import org.embulk.spi.util.Filters;
import org.embulk.spi.util.Executors;
import org.embulk.spi.util.Executors.ProcessStateCallback;
public class LocalExecutorPlugin
implements ExecutorPlugin
{
// Defaults resolved once from the system config; the per-transaction config may override them.
private final int defaultMaxThreads;
private final int defaultMinThreads;

/**
 * Creates the plugin and reads the executor defaults from the system configuration.
 *
 * <p>{@code max_threads} falls back to twice the number of available processors and
 * {@code min_output_tasks} falls back to the number of available processors.
 *
 * @param systemConfig system configuration consulted for {@code max_threads} and
 *        {@code min_output_tasks}
 */
@Inject
public LocalExecutorPlugin(@ForSystemConfig ConfigSource systemConfig)
{
    int cores = Runtime.getRuntime().availableProcessors();
    this.defaultMaxThreads = systemConfig.get(Integer.class, "max_threads", cores * 2);
    this.defaultMinThreads = systemConfig.get(Integer.class, "min_output_tasks", cores);
}
/**
 * Builds a local executor for this run, hands it to {@code control}, and closes the
 * executor when the control callback returns or throws.
 *
 * @param config executor configuration (may override {@code max_threads} and
 *        {@code min_output_tasks})
 * @param outputSchema schema passed through unchanged to {@code control}
 * @param inputTaskCount number of input tasks to run
 * @param control callback that receives the schema, the output task count and the executor
 */
@Override
public void transaction(ConfigSource config, Schema outputSchema, int inputTaskCount,
        ExecutorPlugin.Control control)
{
    try (AbstractLocalExecutor localExecutor = newExecutor(config, inputTaskCount)) {
        int outputTaskCount = localExecutor.getOutputTaskCount();
        control.transaction(outputSchema, outputTaskCount, localExecutor);
    }
}
/**
 * Chooses the executor implementation for a run.
 *
 * <p>When there is at least one input task but fewer input tasks than
 * {@code min_output_tasks}, a {@link ScatterExecutor} is returned so that each input
 * task feeds several output tasks; otherwise a {@link DirectExecutor} is returned.
 *
 * @param config executor configuration; {@code max_threads} and {@code min_output_tasks}
 *        fall back to the plugin defaults
 * @param inputTaskCount number of input tasks
 * @return the executor to run the tasks with
 */
private AbstractLocalExecutor newExecutor(ConfigSource config, int inputTaskCount)
{
    final Logger logger = Exec.getLogger(LocalExecutorPlugin.class);
    final int threadLimit = config.get(Integer.class, "max_threads", defaultMaxThreads);
    final int minOutputTasks = config.get(Integer.class, "min_output_tasks", defaultMinThreads);

    final boolean needsScatter = inputTaskCount > 0 && inputTaskCount < minOutputTasks;
    if (!needsScatter) {
        logger.info("Using local thread executor with max_threads={} / tasks={}", threadLimit, inputTaskCount);
        return new DirectExecutor(threadLimit, inputTaskCount);
    }

    // Round up so that inputTaskCount * fanOut is at least minOutputTasks.
    final int fanOut = (minOutputTasks + inputTaskCount - 1) / inputTaskCount;
    final int totalOutputTasks = inputTaskCount * fanOut;
    logger.info("Using local thread executor with max_threads={} / output tasks {} = input tasks {} * {}",
            threadLimit, totalOutputTasks, inputTaskCount, fanOut);
    return new ScatterExecutor(threadLimit, inputTaskCount, fanOut);
}
/**
 * Base class for executors that run all input tasks on local threads.
 *
 * <p>Subclasses decide how a single input task is started ({@link #startInputTask}) and
 * how their thread pools are released ({@link #close()}); this class drives the tasks,
 * records their outcome in the {@link ProcessState}, and logs progress.
 */
private static abstract class AbstractLocalExecutor
        implements Executor, AutoCloseable
{
    protected final Logger log = Exec.getLogger(LocalExecutorPlugin.class);

    protected final int inputTaskCount;
    protected final int outputTaskCount;

    /**
     * @param inputTaskCount number of input tasks this executor runs
     * @param outputTaskCount number of output tasks those input tasks feed
     */
    public AbstractLocalExecutor(int inputTaskCount, int outputTaskCount)
    {
        this.inputTaskCount = inputTaskCount;
        this.outputTaskCount = outputTaskCount;
    }

    /** Returns the number of output tasks given at construction. */
    public int getOutputTaskCount()
    {
        return outputTaskCount;
    }

    /**
     * Starts every input task, waits for each of them in index order, and stores the
     * outcome of each task in its input task state.
     *
     * <p>A task whose {@link #startInputTask} returned {@code null} is treated as skipped.
     * Failures are recorded in {@code state} rather than thrown. Any future that is still
     * running when this method exits is cancelled with interruption.
     *
     * @param task the task definition passed to every input task
     * @param state the shared state that is initialized here and updated as tasks complete
     */
    @Override
    public void execute(ProcessTask task, ProcessState state)
    {
        state.initialize(inputTaskCount, outputTaskCount);

        List<Future<Throwable>> futures = new ArrayList<>(inputTaskCount);
        try {
            for (int i = 0; i < inputTaskCount; i++) {
                futures.add(startInputTask(task, state, i));
            }
            showProgress(state, inputTaskCount);

            for (int i = 0; i < inputTaskCount; i++) {
                Future<Throwable> future = futures.get(i);
                if (future == null) {
                    // skipped (resumed) task: nothing to wait for
                    continue;
                }
                try {
                    // the future's value is the task's error, or null when it succeeded
                    state.getInputTaskState(i).setException(future.get());
                }
                catch (ExecutionException ex) {
                    state.getInputTaskState(i).setException(ex.getCause());
                }
                catch (InterruptedException ex) {
                    state.getInputTaskState(i).setException(new ExecutionInterruptedException(ex));
                }
                showProgress(state, inputTaskCount);
            }
        }
        finally {
            for (Future<Throwable> future : futures) {
                if (future != null && !future.isDone()) {
                    future.cancel(true);
                    // TODO join?
                }
            }
        }
    }

    /** Releases the resources (thread pools) held by the executor. */
    @Override
    public abstract void close();

    /**
     * Logs how many of the first {@code taskCount} input tasks are finished and how many
     * are currently running.
     *
     * <p>The count passed in is the input task count, so the input task states are the
     * ones inspected: output task indexes only coincide with input task indexes when
     * both counts are equal.
     */
    private void showProgress(ProcessState state, int taskCount)
    {
        int started = 0;
        int finished = 0;
        for (int i = 0; i < taskCount; i++) {
            if (state.getInputTaskState(i).isStarted()) { started++; }
            if (state.getInputTaskState(i).isFinished()) { finished++; }
        }

        log.info(String.format("{done:%3d / %d, running: %d}", finished, taskCount, started - finished));
    }

    /**
     * Starts one input task.
     *
     * @param task the task definition
     * @param state the shared process state
     * @param taskIndex index of the input task to start
     * @return a future for the running task, or {@code null} when the task is skipped
     */
    protected abstract Future<Throwable> startInputTask(ProcessTask task, ProcessState state, int taskIndex);
}
/**
 * Executor in which every input task has exactly one output task with the same index,
 * and both run together on one thread of a fixed-size daemon thread pool.
 */
public static class DirectExecutor
        extends AbstractLocalExecutor
{
    protected final ExecutorService executor;

    /**
     * @param maxThreads size of the fixed thread pool
     * @param taskCount number of tasks, used as both the input and the output task count
     */
    public DirectExecutor(int maxThreads, int taskCount)
    {
        super(taskCount, taskCount);
        this.executor = java.util.concurrent.Executors.newFixedThreadPool(
                maxThreads,
                new ThreadFactoryBuilder()
                        .setNameFormat("embulk-executor-%d")
                        .setDaemon(true)
                        .build());
    }

    /** Requests shutdown of the thread pool; does not wait for running tasks. */
    @Override
    public void close()
    {
        executor.shutdown();
    }

    /**
     * Submits the task at {@code taskIndex} to the pool, unless its output task is
     * already committed, in which case the task is skipped and {@code null} is returned.
     */
    @Override
    protected Future<Throwable> startInputTask(final ProcessTask task, final ProcessState state, final int taskIndex)
    {
        final boolean alreadyCommitted = state.getOutputTaskState(taskIndex).isCommitted();
        if (alreadyCommitted) {
            log.warn("Skipped resumed task {}", taskIndex);
            return null;  // resumed
        }

        final Callable<Throwable> job = new Callable<Throwable>() {
            @Override
            public Throwable call()
            {
                final String threadName = String.format("task-%04d", taskIndex);
                try (SetCurrentThreadName dontCare = new SetCurrentThreadName(threadName)) {
                    Executors.process(Exec.session(), task, taskIndex, stateCallback(state, taskIndex));
                    return null;
                }
                finally {
                    // mark both sides finished whether the task succeeded or threw
                    state.getInputTaskState(taskIndex).finish();
                    state.getOutputTaskState(taskIndex).finish();
                }
            }
        };
        return executor.submit(job);
    }

    /** Builds the callback that mirrors task lifecycle events into {@code state}. */
    private static ProcessStateCallback stateCallback(final ProcessState state, final int taskIndex)
    {
        return new ProcessStateCallback() {
            public void started()
            {
                state.getInputTaskState(taskIndex).start();
                state.getOutputTaskState(taskIndex).start();
            }

            public void inputCommitted(TaskReport report)
            {
                state.getInputTaskState(taskIndex).setTaskReport(report);
            }

            public void outputCommitted(TaskReport report)
            {
                state.getOutputTaskState(taskIndex).setTaskReport(report);
            }
        };
    }
}
/**
 * Executor in which each input task feeds {@code scatterCount} output tasks.
 *
 * <p>Input task {@code t} owns the output tasks with indexes
 * {@code t * scatterCount + i} for {@code 0 <= i < scatterCount}. Input tasks run on a
 * fixed-size pool; the output workers run on a separate cached pool.
 */
public static class ScatterExecutor
        extends AbstractLocalExecutor
{
    private final int scatterCount;
    private final ExecutorService inputExecutor;
    private final ExecutorService outputExecutor;

    /**
     * @param maxThreads thread budget; the input pool gets {@code maxThreads / scatterCount}
     *        threads, but at least one
     * @param inputTaskCount number of input tasks
     * @param scatterCount number of output tasks per input task
     */
    public ScatterExecutor(int maxThreads, int inputTaskCount, int scatterCount)
    {
        super(inputTaskCount, inputTaskCount * scatterCount);
        this.scatterCount = scatterCount;
        this.inputExecutor = java.util.concurrent.Executors.newFixedThreadPool(
                Math.max(maxThreads / scatterCount, 1),
                new ThreadFactoryBuilder()
                        .setNameFormat("embulk-input-executor-%d")
                        .setDaemon(true)
                        .build());
        this.outputExecutor = java.util.concurrent.Executors.newCachedThreadPool(
                new ThreadFactoryBuilder()
                        .setNameFormat("embulk-output-executor-%d")
                        .setDaemon(true)
                        .build());
    }

    /** Requests shutdown of both thread pools; does not wait for running tasks. */
    @Override
    public void close()
    {
        inputExecutor.shutdown();
        outputExecutor.shutdown();
    }

    /**
     * Submits the input task at {@code taskIndex} to the input pool, unless every output
     * task it owns is already committed, in which case {@code null} is returned.
     */
    @Override
    protected Future<Throwable> startInputTask(final ProcessTask task, final ProcessState state, final int taskIndex)
    {
        if (isAllScatterOutputFinished(state, taskIndex)) {
            log.warn("Skipped resumed input task {}", taskIndex);
            return null;  // resumed
        }

        return inputExecutor.submit(new Callable<Throwable>() {
            @Override
            public Throwable call()
            {
                try (SetCurrentThreadName dontCare = new SetCurrentThreadName(String.format("task-%04d", taskIndex))) {
                    runInputTask(Exec.session(), task, state, taskIndex);
                    return null;
                }
            }
        });
    }

    /** Returns true when all output tasks owned by input task {@code taskIndex} are committed. */
    private boolean isAllScatterOutputFinished(ProcessState state, int taskIndex)
    {
        for (int i = 0; i < scatterCount; i++) {
            int outputTaskIndex = taskIndex * scatterCount + i;
            if (!state.getOutputTaskState(outputTaskIndex).isCommitted()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Runs one input task, scattering its pages over the output tasks it owns, and
     * records start, task reports and finish in {@code state}.
     *
     * @param exec session used to instantiate the plugins and the fallback task report
     * @param task the task definition
     * @param state the shared process state
     * @param taskIndex index of the input task to run
     */
    private void runInputTask(ExecSession exec, ProcessTask task, ProcessState state, int taskIndex)
    {
        InputPlugin inputPlugin = exec.newPlugin(InputPlugin.class, task.getInputPluginType());
        List<FilterPlugin> filterPlugins = Filters.newFilterPlugins(exec, task.getFilterPluginTypes());
        OutputPlugin outputPlugin = exec.newPlugin(OutputPlugin.class, task.getOutputPluginType());

        try (ScatterTransactionalPageOutput tran = new ScatterTransactionalPageOutput(state, taskIndex, scatterCount)) {
            tran.openOutputs(outputPlugin, task.getOutputSchema(), task.getOutputTaskSource());

            // aborts the opened outputs on exit unless dontAbort() is reached
            try (AbortTransactionResource aborter = new AbortTransactionResource(tran)) {
                tran.openFilters(filterPlugins, task.getFilterSchemas(), task.getFilterTaskSources());

                tran.startWorkers(outputExecutor);

                // started
                state.getInputTaskState(taskIndex).start();
                for (int i = 0; i < scatterCount; i++) {
                    state.getOutputTaskState(taskIndex * scatterCount + i).start();
                }

                TaskReport inputTaskReport = inputPlugin.run(task.getInputTaskSource(), task.getInputSchema(), taskIndex, tran);

                // inputCommitted
                if (inputTaskReport == null) {
                    inputTaskReport = exec.newTaskReport();
                }
                state.getInputTaskState(taskIndex).setTaskReport(inputTaskReport);

                // outputCommitted
                tran.commit();
                aborter.dontAbort();
            }
        }
        finally {
            state.getInputTaskState(taskIndex).finish();
            // finish exactly the output tasks that were started above; their indexes are
            // taskIndex * scatterCount + i, not taskIndex itself
            for (int i = 0; i < scatterCount; i++) {
                state.getOutputTaskState(taskIndex * scatterCount + i).finish();
            }
        }
    }
}
private static class ScatterTransactionalPageOutput
implements TransactionalPageOutput
{
private static final Page DONE_PAGE = Page.allocate(0);
private static class OutputWorker
implements Callable<Throwable>
{
private final PageOutput output;
private final Future<Throwable> future;
private volatile int addWaiting;
private volatile Page queued;
public OutputWorker(PageOutput output, ExecutorService executor)
{
this.output = output;
this.addWaiting = 0;
this.future = executor.submit(this);
}
public synchronized void done()
throws InterruptedException
{
while (true) {
if (queued == null && addWaiting == 0) {
queued = DONE_PAGE;
notifyAll();
return;
}
else if (queued == DONE_PAGE) {
return;
}
wait();
}
}
public synchronized void add(Page page)
throws InterruptedException
{
addWaiting++;
try {
while (true) {
if (queued == null) {
queued = page;
notifyAll();
return;
}
else if (queued == DONE_PAGE) {
page.release();
return;
}
wait();
}
}
finally {
addWaiting--;
}
}
public Throwable join()
throws InterruptedException
{
try {
return future.get();
}
catch (ExecutionException ex) {
return ex.getCause();
}
}
@Override
public synchronized Throwable call()
throws InterruptedException
{
try {
while (true) {
if (queued != null) {
if (queued == DONE_PAGE) {
return null;
}
output.add(queued);
queued = null;
notifyAll();
}
wait();
}
}
finally {
try {
if (queued != null && queued != DONE_PAGE) {
queued.release();
queued = null;
}
}
finally {
queued = DONE_PAGE;
}
notifyAll();
}
}
}
private final ProcessState state;
private final int taskIndex;
private final int scatterCount;
private final TransactionalPageOutput[] trans;
private final PageOutput[] filtereds;
private final CloseResource[] closeThese;
private final OutputWorker[] outputWorkers;
private long pageCount;
public ScatterTransactionalPageOutput(ProcessState state, int taskIndex, int scatterCount)
{
this.state = state;
this.taskIndex = taskIndex;
this.scatterCount = scatterCount;
this.trans = new TransactionalPageOutput[scatterCount];
this.filtereds = new PageOutput[scatterCount];
this.closeThese = new CloseResource[scatterCount];
for (int i = 0; i < scatterCount; i++) {
closeThese[i] = new CloseResource();
}
this.outputWorkers = new OutputWorker[scatterCount];
}
public void openOutputs(OutputPlugin outputPlugin, Schema outputSchema, TaskSource outputTaskSource)
{
for (int i = 0; i < scatterCount; i++) {
int outputTaskIndex = taskIndex * scatterCount + i;
if (!state.getOutputTaskState(outputTaskIndex).isCommitted()) {
TransactionalPageOutput tran = PluginWrappers.transactionalPageOutput(
outputPlugin.open(outputTaskSource, outputSchema, outputTaskIndex));
trans[i] = tran;
closeThese[i].closeThis(tran);
}
}
}
public void openFilters(List<FilterPlugin> filterPlugins, List<Schema> filterSchemas, List<TaskSource> filterTaskSources)
{
for (int i = 0; i < scatterCount; i++) {
TransactionalPageOutput tran = trans[i];
if (tran != null) {
PageOutput filtered = Filters.open(filterPlugins, filterTaskSources, filterSchemas, trans[i]);
filtereds[i] = filtered;
closeThese[i].closeThis(filtered);
}
}
}
public void startWorkers(ExecutorService outputExecutor)
{
for (int i = 0; i < scatterCount; i++) {
PageOutput filtered = filtereds[i];
if (filtered != null) {
outputWorkers[i] = new OutputWorker(filtered, outputExecutor);
}
}
}
public void add(Page page)
{
OutputWorker worker = outputWorkers[(int) (pageCount % scatterCount)];
if (worker != null) {
try {
worker.add(page);
}
catch (InterruptedException ex) {
throw Throwables.propagate(ex);
}
}
pageCount++;
}
public void finish()
{
completeWorkers();
for (int i = 0; i < scatterCount; i++) {
if (filtereds[i] != null) {
filtereds[i].finish();
}
}
}
public void close()
{
completeWorkers();
for (int i = 0; i < scatterCount; i++) {
closeThese[i].close();
}
}
public void abort()
{
completeWorkers();
for (int i = 0; i < scatterCount; i++) {
if (trans[i] != null) {
trans[i].abort();
}
}
}
public TaskReport commit()
{
completeWorkers();
for (int i = 0; i < scatterCount; i++) {
if (trans[i] != null) {
int outputTaskIndex = taskIndex * scatterCount + i;
TaskReport outputTaskReport = trans[i].commit();
trans[i] = null; // don't abort
if (outputTaskReport == null) {
outputTaskReport = Exec.newTaskReport();
}
state.getOutputTaskState(outputTaskIndex).setTaskReport(outputTaskReport);
}
}
return null;
}
public void completeWorkers()
{
for (int i = 0; i < scatterCount; i++) {
OutputWorker worker = outputWorkers[i];
if (worker != null) {
try {
worker.done();
}
catch (InterruptedException ex) {
throw Throwables.propagate(ex);
}
Throwable error = null;
try {
error = worker.join();
}
catch (InterruptedException ex) {
error = ex;
}
outputWorkers[i] = null;
if (error != null) {
throw Throwables.propagate(error);
}
}
}
}
}
}