/*
* Copyright 2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.statemachine.recipes.tasks;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.statemachine.StateContext;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.StateMachineException;
import org.springframework.statemachine.StateMachinePersist;
import org.springframework.statemachine.access.StateMachineAccess;
import org.springframework.statemachine.access.StateMachineFunction;
import org.springframework.statemachine.action.Action;
import org.springframework.statemachine.config.StateMachineBuilder;
import org.springframework.statemachine.config.builders.StateMachineStateConfigurer;
import org.springframework.statemachine.config.builders.StateMachineTransitionConfigurer;
import org.springframework.statemachine.guard.Guard;
import org.springframework.statemachine.listener.AbstractCompositeListener;
import org.springframework.statemachine.recipes.support.RunnableAction;
import org.springframework.statemachine.state.PseudoStateKind;
import org.springframework.statemachine.state.State;
import org.springframework.statemachine.support.DefaultStateMachineContext;
import org.springframework.statemachine.support.StateMachineInterceptor;
import org.springframework.statemachine.support.StateMachineInterceptorAdapter;
import org.springframework.statemachine.support.StateMachineUtils;
import org.springframework.statemachine.support.tree.Tree;
import org.springframework.statemachine.support.tree.Tree.Node;
import org.springframework.statemachine.support.tree.TreeTraverser;
import org.springframework.statemachine.transition.Transition;
/**
* {@code TasksHandler} is a recipe for executing arbitrary {@link Runnable} tasks
* using a state machine logic.
*
* This recipe supports execution of multiple top-level tasks with a
* sub-states construct of DAGs.
*
* @author Janne Valkealahti
*
*/
public class TasksHandler {
private final static Log log = LogFactory.getLog(TasksHandler.class);
public final static String STATE_READY = "READY";
public final static String STATE_FORK = "FORK";
public final static String STATE_TASKS = "TASKS";
public final static String STATE_JOIN = "JOIN";
public final static String STATE_CHOICE = "CHOICE";
public final static String STATE_ERROR = "ERROR";
public final static String STATE_AUTOMATIC = "AUTOMATIC";
public final static String STATE_MANUAL = "MANUAL";
public final static String STATE_TASKS_PREFIX = "TASK_";
public final static String STATE_TASKS_INITIAL_POSTFIX = "_INITIAL";
public final static String EVENT_RUN = "RUN";
public final static String EVENT_FALLBACK = "FALLBACK";
public final static String EVENT_CONTINUE = "CONTINUE";
public final static String EVENT_FIX = "FIX";
private StateMachine<String, String> stateMachine;
private final CompositeTasksListener listener = new CompositeTasksListener();
private final StateMachinePersist<String, String, Void> persist;
/**
* Instantiates a new tasks handler. Intentionally private instantiation
* meant to be called from a builder.
*
* @param tasks the wrapped tasks
* @param listener the tasks listener
* @param taskExecutor the task executor
* @param persist the state machine persist
*/
private TasksHandler(List<TaskWrapper> tasks, TasksListener listener, TaskExecutor taskExecutor,
StateMachinePersist<String, String, Void> persist) {
this.persist = persist;
try {
stateMachine = buildStateMachine(tasks, taskExecutor);
if (persist != null) {
final LocalStateMachineInterceptor interceptor = new LocalStateMachineInterceptor(persist);
stateMachine.getStateMachineAccessor()
.doWithAllRegions(new StateMachineFunction<StateMachineAccess<String, String>>() {
@Override
public void apply(StateMachineAccess<String, String> function) {
function.addStateMachineInterceptor(interceptor);
}
});
}
} catch (Exception e) {
throw new StateMachineException("Error building state machine from tasks", e);
}
if (listener != null) {
addTasksListener(listener);
}
}
/**
* Request to execute current tasks logic.
*/
public void runTasks() {
stateMachine.sendEvent(EVENT_RUN);
}
/**
* Request to continue from an error.
*/
public void continueFromError() {
stateMachine.sendEvent(EVENT_CONTINUE);
}
/**
* Request to fix current problems.
*/
public void fixCurrentProblems() {
stateMachine.sendEvent(EVENT_FIX);
}
/**
* Resets state machine states from a backing persistent repository. If
* {@link StateMachinePersist} is not set this method doesn't do anything.
* {@link StateMachine} is stopped before states are reseted from a persistent
* store and started afterwards.
*/
public void resetFromPersistStore() {
if (persist == null) {
// TODO: should we throw or silently return?
return;
}
final StateMachineContext<String, String> context;
try {
context = persist.read(null);
} catch (Exception e) {
throw new StateMachineException("Error reading state from persistent store", e);
}
stateMachine.stop();
stateMachine.getStateMachineAccessor()
.doWithAllRegions(new StateMachineFunction<StateMachineAccess<String, String>>() {
@Override
public void apply(StateMachineAccess<String, String> function) {
function.resetStateMachine(context);
}
});
stateMachine.start();
}
/**
* Adds the tasks listener.
*
* @param listener the listener
*/
public void addTasksListener(TasksListener listener) {
this.listener.register(listener);
}
/**
* Removes the tasks listener.
*
* @param listener the listener
*/
public void removeTasksListener(TasksListener listener) {
this.listener.unregister(listener);
}
/**
* Gets the internal state machine used by executing tasks.
*
* @return the state machine
*/
public StateMachine<String, String> getStateMachine() {
return stateMachine;
}
/**
* Gets a new instance of a {@link Builder} which is used to build
* an instance of a {@code TasksHandler}.
*
* @return the tasks handler builder
*/
public static Builder builder() {
return new Builder();
}
/**
* Mark all extended state variables related to tasks fixed.
*/
public void markAllTasksFixed() {
Map<Object, Object> variables = getStateMachine().getExtendedState().getVariables();
for (Entry<Object, Object> entry : variables.entrySet()) {
if (entry.getKey() instanceof String && ((String)entry.getKey()).startsWith(STATE_TASKS_PREFIX)) {
if (entry.getValue() instanceof Integer) {
Integer value = (Integer) entry.getValue();
if (value < 0) {
variables.put(entry.getKey(), 0);
}
}
}
}
}
private StateMachine<String, String> buildStateMachine(List<TaskWrapper> tasks, TaskExecutor taskExecutor)
throws Exception {
StateMachineBuilder.Builder<String, String> builder = StateMachineBuilder.builder();
int taskCount = topLevelTaskCount(tasks);
builder.configureConfiguration().withConfiguration()
.taskExecutor(taskExecutor != null ? taskExecutor : taskExecutor(taskCount));
StateMachineStateConfigurer<String, String> stateMachineStateConfigurer = builder.configureStates();
StateMachineTransitionConfigurer<String, String> stateMachineTransitionConfigurer = builder.configureTransitions();
stateMachineStateConfigurer
.withStates()
.initial(STATE_READY)
.fork(STATE_FORK)
.state(STATE_TASKS, tasksEntryAction(), null)
.join(STATE_JOIN)
.choice(STATE_CHOICE)
.state(STATE_ERROR);
stateMachineTransitionConfigurer
.withExternal()
.source(STATE_READY).target(STATE_FORK).event(EVENT_RUN)
.and()
.withFork()
.source(STATE_FORK).target(STATE_TASKS);
Iterator<Node<TaskWrapper>> iterator = buildTasksIterator(tasks);
String parent = null;
Collection<String> joinStates = new ArrayList<String>();
while (iterator.hasNext()) {
Node<TaskWrapper> node = iterator.next();
if (node.getData() == null) {
break;
}
String initial = STATE_TASKS_PREFIX + node.getData().id.toString() + STATE_TASKS_INITIAL_POSTFIX;
String task = STATE_TASKS_PREFIX + node.getData().id.toString();
parent = node.getData().parent != null ? STATE_TASKS_PREFIX + node.getData().parent.toString() : STATE_TASKS;
stateMachineStateConfigurer
.withStates()
.parent(parent)
.initial(initial)
.state(task, runnableAction(node.getData().runnable, node.getData().id.toString()), null);
if (node.getChildren().isEmpty()) {
joinStates.add(task);
}
stateMachineTransitionConfigurer
.withExternal()
.state(parent)
.source(initial)
.target(task);
}
stateMachineStateConfigurer
.withStates()
.parent(STATE_ERROR)
.initial(STATE_AUTOMATIC)
.state(STATE_AUTOMATIC, automaticAction(), null)
.state(STATE_MANUAL);
stateMachineTransitionConfigurer
.withJoin()
.sources(joinStates)
.target(STATE_JOIN)
.and()
.withExternal()
.source(STATE_JOIN).target(STATE_CHOICE)
.and()
.withChoice()
.source(STATE_CHOICE)
.first(STATE_ERROR, tasksChoiceGuard())
.last(STATE_READY)
.and()
.withExternal()
.source(STATE_ERROR).target(STATE_READY)
.event(EVENT_CONTINUE)
.action(continueAction())
.and()
.withExternal()
.source(STATE_AUTOMATIC).target(STATE_MANUAL)
.event(EVENT_FALLBACK)
.and()
.withInternal()
.source(STATE_MANUAL)
.action(fixAction())
.event(EVENT_FIX);
return builder.build();
}
private static TaskExecutor taskExecutor(int taskCount) {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
taskExecutor.afterPropertiesSet();
taskExecutor.setCorePoolSize(taskCount);
return taskExecutor;
}
private static int topLevelTaskCount(List<TaskWrapper> tasks) {
Tree<TaskWrapper> tree = new Tree<TaskWrapper>();
for (TaskWrapper wrapper : tasks) {
tree.add(wrapper, wrapper.id, wrapper.parent);
}
return tree.getRoot().getChildren().size();
}
private static Iterator<Node<TaskWrapper>> buildTasksIterator(List<TaskWrapper> tasks) {
Tree<TaskWrapper> tree = new Tree<TaskWrapper>();
for (TaskWrapper wrapper : tasks) {
tree.add(wrapper, wrapper.id, wrapper.parent);
}
TreeTraverser<Node<TaskWrapper>> traverser = new TreeTraverser<Node<TaskWrapper>>() {
@Override
public Iterable<Node<TaskWrapper>> children(Node<TaskWrapper> root) {
return root.getChildren();
}
};
Iterable<Node<TaskWrapper>> postOrderTraversal = traverser.postOrderTraversal(tree.getRoot());
Iterator<Node<TaskWrapper>> iterator = postOrderTraversal.iterator();
return iterator;
}
/**
* Builder pattern implementation building a {@link TasksHandler}.
*/
public static class Builder {
private final List<TaskWrapper> tasks = new ArrayList<TaskWrapper>();
private TasksListener listener;
private TaskExecutor taskExecutor;
private StateMachinePersist<String, String, Void> persist;
/**
* Define a top-level task.
*
* @param id the id
* @param runnable the runnable
* @return the builder for chaining
*/
public Builder task(Object id, Runnable runnable) {
tasks.add(new TaskWrapper(null, id, runnable));
return this;
}
/**
* Define a sub-task with a reference to its parent.
*
* @param parent the parent
* @param id the id
* @param runnable the runnable
* @return the builder for chaining
*/
public Builder task(Object parent, Object id, Runnable runnable) {
tasks.add(new TaskWrapper(parent, id, runnable));
return this;
}
/**
* Define a {@link StateMachinePersist} implementation if state machine
* should be persisted with state changes.
*
* @param persist the persist
* @return the builder for chaining
*/
public Builder persist(StateMachinePersist<String, String, Void> persist) {
this.persist = persist;
return this;
}
/**
* Define a {@link TasksListener} to be registered.
*
* @param listener the tasks listener
* @return the builder for chaining
*/
public Builder listener(TasksListener listener) {
this.listener = listener;
return this;
}
/**
* Define a {@link TaskExecutor} to be used. Default executor will be
* a {@link ThreadPoolTaskExecutor} set with a thread pool size of
* a top-level task count.
*
* @param taskExecutor the task executor
* @return the builder for chaining
*/
public Builder taskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
return this;
}
/**
* Builds the {@link TasksHandler}.
*
* @return the tasks handler
*/
public TasksHandler build() {
return new TasksHandler(tasks, listener, taskExecutor, persist);
}
}
/**
* Gets a tasks entry action.
*
* @return the tasks entry action
*/
private TasksEntryAction tasksEntryAction() {
return new TasksEntryAction();
}
/**
* Gets a local runnable action.
*
* @param runnable the runnable
* @param id the task id
* @return the local runnable action
*/
private LocalRunnableAction runnableAction(Runnable runnable, String id) {
return new LocalRunnableAction(runnable, id);
}
/**
* Tasks choice guard. This {@link Guard} will check if related
* extended state variables contains negative values for related
* tasks id's and returns true if so, else false.
*
* @return the guard
*/
private Guard<String, String> tasksChoiceGuard() {
return new Guard<String, String>() {
@Override
public boolean evaluate(StateContext<String, String> context) {
Map<Object, Object> variables = context.getExtendedState().getVariables();
for (Entry<Object, Object> entry : variables.entrySet()) {
if (entry.getKey() instanceof String && ((String)entry.getKey()).startsWith(STATE_TASKS_PREFIX)) {
if (entry.getValue() instanceof Integer) {
Integer value = (Integer) entry.getValue();
if (value < 0) {
if (log.isDebugEnabled()) {
log.debug("Task id=[" + entry.getKey() + "] has negative execution value, tasksChoiceGuard returns true");
}
listener.onTasksError();
return true;
}
}
}
}
listener.onTasksSuccess();
return false;
}
};
}
/**
* {@link Action} which simply sends an event of continue
* tasks into a state machine.
*
* @return the action
*/
private Action<String, String> continueAction() {
return new Action<String, String>() {
@Override
public void execute(StateContext<String, String> context) {
listener.onTasksContinue();
}
};
}
/**
* {@link Action} calls {@link TasksListener#onTasksAutomaticFix(StateContext)}
* before checking status of extended state variables related to tasks. If all
* variables are ok, event {@code EVENT_CONTINUE} is sent, otherwise event
* {@code EVENT_FALLBACK} is send which takes state machine into a manual handling.
*
* @return the action
*/
private Action<String, String> automaticAction() {
return new Action<String, String>() {
@Override
public void execute(StateContext<String, String> context) {
listener.onTasksAutomaticFix(TasksHandler.this, context);
boolean hasErrors = false;
Map<Object, Object> variables = context.getExtendedState().getVariables();
for (Entry<Object, Object> entry : variables.entrySet()) {
if (entry.getKey() instanceof String && ((String)entry.getKey()).startsWith(STATE_TASKS_PREFIX)) {
if (entry.getValue() instanceof Integer) {
Integer value = (Integer) entry.getValue();
if (value < 0) {
hasErrors = true;
break;
}
}
}
}
if (hasErrors) {
context.getStateMachine().sendEvent(EVENT_FALLBACK);
} else {
context.getStateMachine().sendEvent(EVENT_CONTINUE);
}
}
};
}
/**
* {@link Action} which resets related extended state variables
* to zero for tasks order to indicate a fixed tasks.
*
* @return the action
*/
private Action<String, String> fixAction() {
return new Action<String, String>() {
@Override
public void execute(StateContext<String, String> context) {
Map<Object, Object> variables = context.getExtendedState().getVariables();
for (Entry<Object, Object> entry : variables.entrySet()) {
if (entry.getKey() instanceof String && ((String)entry.getKey()).startsWith(STATE_TASKS_PREFIX)) {
if (entry.getValue() instanceof Integer) {
Integer value = (Integer) entry.getValue();
if (value < 0) {
variables.put(entry.getKey(), 0);
}
}
}
}
}
};
}
/**
* Adapter class for {@link TasksListener}.
*/
public static class TasksListenerAdapter implements TasksListener {
@Override
public void onTasksStarted() {
}
@Override
public void onTasksContinue() {
}
@Override
public void onTaskPreExecute(Object id) {
}
@Override
public void onTaskPostExecute(Object id) {
}
@Override
public void onTaskFailed(Object id, Exception exception) {
}
@Override
public void onTaskSuccess(Object id) {
}
@Override
public void onTasksSuccess() {
}
@Override
public void onTasksError() {
}
@Override
public void onTasksAutomaticFix(TasksHandler handler, StateContext<String, String> context) {
}
}
/**
* {@code TasksListener} is a generic interface listening tasks
* execution events. Methods in this interface will be called in a
* tasks execution position where user most likely will want to get
* notified.
*/
public interface TasksListener {
/**
* Called when all DAGs have either never executed or previous
* execution was fully successful.
*/
void onTasksStarted();
/**
* Called when some of a tasks in DAGs failed to execute and tasks
* execution in going to continue.
*/
void onTasksContinue();
/**
* Called before tasks is about to be executed.
*
* @param id the task id
*/
void onTaskPreExecute(Object id);
/**
* Called after tasks has been executed regardless if task
* execution succeeded or not.
*
* @param id the task id
*/
void onTaskPostExecute(Object id);
/**
* Called when task execution result an error of any kind.
*
* @param id the task id
* @param exception the exception
*/
void onTaskFailed(Object id, Exception exception);
/**
* Called when task execution result without errors.
*
* @param id the task id
*/
void onTaskSuccess(Object id);
/**
* Called when all tasks has been executed successfully.
*/
void onTasksSuccess();
/**
* Called when after an execution of full DAGs if some of the
* tasks executed with an error.
*/
void onTasksError();
/**
* Called when tasks execution resulted an error and AUTOMATIC state
* is entered. This is a moment where extended state variables can be
* modified to allow continue into a READY state.
*
* @param handler the tasks handler
* @param context the state context
*/
void onTasksAutomaticFix(TasksHandler handler, StateContext<String, String> context);
}
private class CompositeTasksListener extends AbstractCompositeListener<TasksListener> implements
TasksListener {
@Override
public void onTasksStarted() {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTasksStarted();
}
}
@Override
public void onTasksContinue() {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTasksContinue();
}
}
@Override
public void onTaskPreExecute(Object id) {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTaskPreExecute(id);
}
}
@Override
public void onTaskPostExecute(Object id) {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTaskPostExecute(id);
}
}
@Override
public void onTaskFailed(Object id, Exception exception) {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTaskFailed(id, exception);
}
}
@Override
public void onTaskSuccess(Object id) {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTaskSuccess(id);
}
}
@Override
public void onTasksSuccess() {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTasksSuccess();
}
}
@Override
public void onTasksError() {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTasksError();
}
}
@Override
public void onTasksAutomaticFix(TasksHandler handler, StateContext<String, String> context) {
for (Iterator<TasksListener> iterator = getListeners().reverse(); iterator.hasNext();) {
iterator.next().onTasksAutomaticFix(handler, context);
}
}
}
/**
* {@link Action} which is executed when TASKS state is entered.
*/
private class TasksEntryAction implements Action<String, String> {
@Override
public void execute(StateContext<String, String> context) {
boolean hasErrors = false;
Map<Object, Object> variables = context.getExtendedState().getVariables();
for (Entry<Object, Object> entry : variables.entrySet()) {
if (entry.getKey() instanceof String && ((String)entry.getKey()).startsWith(STATE_TASKS_PREFIX)) {
if (entry.getValue() instanceof Integer) {
Integer value = (Integer) entry.getValue();
if (value < 0) {
hasErrors = true;
break;
}
}
}
}
if (hasErrors) {
listener.onTasksContinue();
} else {
listener.onTasksStarted();
}
}
}
/**
* {@link Action} which is executed with every registered {@link Runnable}.
*/
private class LocalRunnableAction extends RunnableAction {
public LocalRunnableAction(Runnable runnable, String id) {
super(runnable, id);
}
@Override
protected boolean shouldExecute(String id, StateContext<String, String> context) {
return super.shouldExecute(id, context);
}
@Override
protected void onPreExecute(String id, StateContext<String, String> context) {
listener.onTaskPreExecute(id);
}
@Override
protected void onPostExecute(String id, StateContext<String, String> context) {
listener.onTaskPostExecute(id);
}
@Override
protected void onSuccess(String id, StateContext<String, String> context) {
listener.onTaskSuccess(id);
changeCount(1, context);
}
@Override
protected void onError(String id, StateContext<String, String> context, Exception e) {
listener.onTaskFailed(id, e);
changeCount(-1, context);
}
private void changeCount(int delta, StateContext<String, String> context) {
Map<Object, Object> variables = context.getExtendedState().getVariables();
Integer count;
String key = STATE_TASKS_PREFIX + getId();
if (variables.containsKey(key)) {
count = (Integer) variables.get(key);
} else {
count = 0;
}
count =+ delta;
variables.put(key, count);
}
}
/**
* Local {@link StateMachineInterceptor} persisting state machine states.
*/
private class LocalStateMachineInterceptor extends StateMachineInterceptorAdapter<String, String> {
// TODO: should try to find a common way to build context and
// not do tweaks here.
private final StateMachinePersist<String, String, Void> persist;
private DefaultStateMachineContext<String, String> currentContext;
private State<String, String> currentContextState;
private final List<StateMachineContext<String, String>> childs = new ArrayList<StateMachineContext<String, String>>();
public LocalStateMachineInterceptor(StateMachinePersist<String, String, Void> persist) {
this.persist = persist;
}
@Override
public void preStateChange(State<String, String> state, Message<String> message,
Transition<String, String> transition, StateMachine<String, String> stateMachine) {
// skip all other pseudostates than initial
if (state == null || (state.getPseudoState() != null && state.getPseudoState().getKind() != PseudoStateKind.INITIAL)) {
return;
}
// track root state here and update childs
if (currentContext != null && StateMachineUtils.isSubstate(currentContextState, state)) {
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(
transition != null ? transition.getTarget().getId() : null, message != null ? message.getPayload()
: null, message != null ? message.getHeaders() : null, stateMachine.getExtendedState());
currentContext.getChilds().add(context);
} else {
childs.clear();
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(
new ArrayList<StateMachineContext<String, String>>(childs), state.getId(), message != null ? message.getPayload()
: null, message != null ? message.getHeaders() : null, stateMachine.getExtendedState());
currentContext = context;
currentContextState = state;
}
try {
persist.write(currentContext, null);
} catch (Exception e) {
throw new StateMachineException("Error persisting", e);
}
}
}
/**
* Wrapping a {@link Runnable} with a task identifier and parent if task
* is a subtask. If parent is null it indicates that a task is a top-level
* task with optional child tasks creating a dag task graph.
*/
private static class TaskWrapper {
final Object parent;
final Object id;
final Runnable runnable;
public TaskWrapper(Object parent, Object id, Runnable runnable) {
this.parent = parent;
this.id = id;
this.runnable = runnable;
}
}
}