package com.marklogic.client.batch; import com.marklogic.client.document.DocumentWriteOperation; import com.marklogic.client.helper.LoggingObject; import org.springframework.core.task.AsyncListenableTaskExecutor; import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.SyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.scheduling.concurrent.ExecutorConfigurationSupport; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureCallback; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; /** * Support class for BatchWriter implementations that uses Spring's TaskExecutor interface for parallelizing writes to * MarkLogic. Allows for setting a TaskExecutor instance, and if one is not set, a default one will be created based * on the threadCount attribute. That attribute is ignored if a TaskExecutor is set. */ public abstract class BatchWriterSupport extends LoggingObject implements BatchWriter { private TaskExecutor taskExecutor; private int threadCount = 16; private WriteListener writeListener; @Override public void initialize() { if (taskExecutor == null) { initializeDefaultTaskExecutor(); } } @Override public void waitForCompletion() { if (taskExecutor instanceof ExecutorConfigurationSupport) { if (logger.isInfoEnabled()) { logger.info("Calling shutdown on thread pool"); } ((ExecutorConfigurationSupport) taskExecutor).shutdown(); if (logger.isInfoEnabled()) { logger.info("Thread pool finished shutdown"); } } } protected void initializeDefaultTaskExecutor() { if (threadCount > 1) { if (logger.isInfoEnabled()) { logger.info("Initializing thread pool with a count of " + threadCount); } ThreadPoolTaskExecutor tpte = new ThreadPoolTaskExecutor(); tpte.setCorePoolSize(threadCount); // By default, wait for tasks to finish, and wait up to an hour tpte.setWaitForTasksToCompleteOnShutdown(true); tpte.setAwaitTerminationSeconds(60 * 60); tpte.afterPropertiesSet(); this.taskExecutor = tpte; } else { if (logger.isInfoEnabled()) { logger.info("Thread count is 1, so using a synchronous TaskExecutor"); } this.taskExecutor = new SyncTaskExecutor(); } } /** * Will use the WriteListener if the TaskExecutor is an instance of AsyncListenableTaskExecutor. The WriteListener * will then be used to listen for failures. * * @param runnable * @param items */ protected void executeRunnable(Runnable runnable, final List<? extends DocumentWriteOperation> items) { if (writeListener != null && taskExecutor instanceof AsyncListenableTaskExecutor) { AsyncListenableTaskExecutor asyncListenableTaskExecutor = (AsyncListenableTaskExecutor)taskExecutor; ListenableFuture<?> future = asyncListenableTaskExecutor.submitListenable(runnable); future.addCallback(new ListenableFutureCallback<Object>() { @Override public void onFailure(Throwable ex) { writeListener.onWriteFailure(ex, items); } @Override public void onSuccess(Object result) { } }); } else { taskExecutor.execute(runnable); } } protected TaskExecutor getTaskExecutor() { return taskExecutor; } public void setTaskExecutor(TaskExecutor taskExecutor) { this.taskExecutor = taskExecutor; } public void setThreadCount(int threadCount) { this.threadCount = threadCount; } protected WriteListener getWriteListener() { return writeListener; } public void setWriteListener(WriteListener writeListener) { this.writeListener = writeListener; } }