package org.infinispan.test.concurrent; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.infinispan.util.logging.Log; import org.infinispan.util.logging.LogFactory; import net.jcip.annotations.GuardedBy; /** * Defines a set of logical threads, each with a list of states, and a partial ordering between states. * <p/> * <p>Logical threads are defined with {@link #logicalThread(String, String, String...)}. States in a logical thread are implicitly * ordered - they must be entered in the order in which they were defined.</p> * <p>The ordering between states in different logical threads can be defined with {@link #order(String, String, String...)}</p> * <p>A state can also have an associated action, defined with {@link #action(String, java.util.concurrent.Callable)}. * States that depend on another state with an associated action can only be entered after the action has finished.</p> * <p>Entering a state with {@link #enter(String)} will block until all the other states it depends on have been exited * with {@link #exit(String)}.</p> * * @author Dan Berindei * @since 7.0 */ public class StateSequencer { private static final Log log = LogFactory.getLog(StateSequencer.class); private final Map<String, LogicalThread> logicalThreads = new HashMap<String, LogicalThread>(); private final Map<String, State> stateMap = new HashMap<String, State>(); private final Lock lock = new ReentrantLock(); private final Condition condition = lock.newCondition(); private final long defaultTimeoutNanos; private boolean running = true; public StateSequencer() { this(30, TimeUnit.SECONDS); } public StateSequencer(long defaultTimeout, TimeUnit unit) { this.defaultTimeoutNanos = unit.toNanos(defaultTimeout); } /** * Define a logical thread. * <p/> * States in a logical thread are implicitly ordered - they must be entered in the order in which they were defined. */ public StateSequencer logicalThread(String threadName, String initialState, String... additionalStates) { lock.lock(); try { if (logicalThreads.containsKey(threadName)) { throw new IllegalArgumentException("Logical thread " + threadName + " already exists"); } List<String> states; if (additionalStates == null) { states = Collections.singletonList(initialState); } else { states = new ArrayList<String>(additionalStates.length + 1); states.add(initialState); states.addAll(Arrays.asList(additionalStates)); } LogicalThread thread = new LogicalThread(threadName, states); logicalThreads.put(threadName, thread); for (String stateName : states) { if (stateMap.containsKey(stateName)) { throw new IllegalArgumentException("State " + stateName + " already exists"); } State state = new State(threadName, stateName); stateMap.put(stateName, state); } doOrder(states); log.tracef("Added logical thread %s, with states %s", threadName, states); } finally { lock.unlock(); } return this; } private void doOrder(List<String> orderedStates) { lock.lock(); try { for (int i = 0; i < orderedStates.size(); i++) { State state = stateMap.get(orderedStates.get(i)); if (state == null) { throw new IllegalArgumentException("Cannot order a non-existing state: " + orderedStates.get(i)); } if (i > 0) { state.dependencies.add(orderedStates.get(i - 1)); } } verifyCycles(); log.tracef("Order changed: %s", getOrderString()); } finally { lock.unlock(); } } @GuardedBy("lock") private void verifyCycles() { visitInOrder(new StatesVisitor() { @Override public void visitStates(List<String> visitedStates) { // Do nothing } @Override public void visitCycle(Collection<String> remainingStates) { throw new IllegalStateException("Cycle detected: " + remainingStates); } }); } private String getOrderString() { final StringBuilder sb = new StringBuilder(); visitInOrder(new StatesVisitor() { @Override public void visitStates(List<String> visitedStates) { if (sb.length() > 1) { sb.append(" < "); } if (visitedStates.size() == 1) { sb.append(visitedStates.get(0)); } else { sb.append(visitedStates); } } @Override public void visitCycle(Collection<String> remainingStates) { sb.append("cycle: ").append(remainingStates); } }); return sb.toString(); } @GuardedBy("lock") private void visitInOrder(StatesVisitor visitor) { Set<String> visitedStates = new HashSet<String>(); Set<String> remainingStates = new HashSet<String>(stateMap.keySet()); while (!remainingStates.isEmpty()) { // In every iteration, we visit the states for which we already visited all their dependencies. // If there are no such states, it means we found a cycle. List<String> freeStates = new ArrayList<String>(); for (Iterator<String> it = remainingStates.iterator(); it.hasNext(); ) { State s = stateMap.get(it.next()); if (visitedStates.containsAll(s.dependencies)) { freeStates.add(s.name); it.remove(); } } visitedStates.addAll(freeStates); if (freeStates.size() != 0) { visitor.visitStates(freeStates); } else { visitor.visitCycle(remainingStates); } } } /** * Define a partial order between states in different logical threads. */ public StateSequencer order(String state1, String state2, String... additionalStates) { List<String> allStates; if (additionalStates == null) { allStates = new ArrayList<String>(Arrays.asList(state1, state2)); } else { allStates = new ArrayList<String>(additionalStates.length + 2); allStates.add(state1); allStates.add(state2); allStates.addAll(Arrays.asList(additionalStates)); } doOrder(allStates); return this; } /** * Define an action for a state. * <p/> * States that depend on another state with an associated action can only be entered after the action has finished. */ public StateSequencer action(String stateName, Callable<Object> action) { lock.lock(); try { State state = stateMap.get(stateName); if (state == null) { throw new IllegalArgumentException("Trying to add an action for an invalid state: " + stateName); } if (state.action != null) { throw new IllegalStateException("Trying to overwrite an existing action for state " + stateName); } state.action = action; log.tracef("Action added for state %s", stateName); } finally { lock.unlock(); } return this; } /** * Equivalent to {@code enter(state, timeout, unit); exit(state);}. */ public void advance(String state, long timeout, TimeUnit unit) throws TimeoutException, InterruptedException { enter(state, timeout, unit); exit(state); } /** * Enter a state and block until all its dependencies have been exited. */ public void enter(String stateName, long timeout, TimeUnit unit) throws TimeoutException, InterruptedException { doEnter(stateName, unit.toNanos(timeout)); } /** * Exit a state and signal the waiters on its dependent states. */ public void exit(String stateName) { log.tracef("Exiting state %s", stateName); lock.lock(); try { if (!running) return; State state = stateMap.get(stateName); if (state.signalled) { throw new IllegalStateException(String.format("State %s exited twice", stateName)); } state.signalled = true; condition.signalAll(); } finally { lock.unlock(); } } private void doEnter(String stateName, long nanos) throws InterruptedException, TimeoutException { lock.lock(); try { State state = stateMap.get(stateName); if (state == null) { throw new IllegalArgumentException("Trying to advance to a non-existing state: " + stateName); } if (!running) { log.tracef("Sequencer stopped, not entering state %s", stateName); return; } log.tracef("Waiting for states %s to enter %s", state.dependencies, stateName); for (String dependency : state.dependencies) { State depState = stateMap.get(dependency); nanos = waitForState(depState, nanos); if (nanos <= 0 && !depState.signalled) { reportTimeout(state); } } log.tracef("Entering state %s", stateName); logicalThreads.get(state.threadName).setCurrentState(stateName); if (state.action != null) { try { state.action.call(); } catch (Exception e) { throw new RuntimeException("Action failed for state " + stateName, e); } } } finally { lock.unlock(); } } @GuardedBy("lock") private long waitForState(State state, long nanos) throws InterruptedException { while (running && !state.signalled && nanos > 0L) { nanos = condition.awaitNanos(nanos); } return nanos; } @GuardedBy("lock") private void reportTimeout(State state) throws TimeoutException { List<String> timedOutStates = new ArrayList<String>(1); for (String dependencyName : state.dependencies) { State dependency = stateMap.get(dependencyName); if (!dependency.signalled) { timedOutStates.add(dependencyName); } } String errorMessage = String.format("Timed out waiting to enter state %s. Dependencies not satisfied are %s", state.name, timedOutStates); log.trace(errorMessage); throw new TimeoutException(errorMessage); } /** * Equivalent to {@code enter(state); exit(state);}. */ public void advance(String state) throws TimeoutException, InterruptedException { enter(state); exit(state); } /** * Enter a state and block until all its dependencies have been exited, using the default timeout. */ public void enter(String stateName) throws TimeoutException, InterruptedException { doEnter(stateName, defaultTimeoutNanos); } /** * Stop doing anything on {@code enter()} or {@code exit()}. * Existing threads waiting in {@code enter()} will be waken up. */ public void stop() { lock.lock(); try { log.tracef("Stopping sequencer %s", toString()); running = false; condition.signalAll(); } finally { lock.unlock(); } } public String toString() { lock.lock(); try { StringBuilder sb = new StringBuilder(); sb.append("Sequencer{ "); for (LogicalThread thread : logicalThreads.values()) { sb.append(thread); sb.append("; "); } sb.append("global order: ").append(getOrderString()); sb.append("}"); return sb.toString(); } finally { lock.unlock(); } } private interface StatesVisitor { void visitStates(List<String> visitedStates); void visitCycle(Collection<String> remainingStates); } private static class State { final String threadName; final String name; final List<String> dependencies; Callable<Object> action; boolean signalled; public State(String threadName, String name) { this.threadName = threadName; this.name = name; this.dependencies = new ArrayList<String>(); } } private static class LogicalThread { final String name; final List<String> states; String currentState; public LogicalThread(String name, List<String> states) { this.name = name; this.states = states; } public void setCurrentState(String state) { this.currentState = state; } public String toString() { StringBuilder sb = new StringBuilder(); sb.append(name).append(": "); for (int i = 0; i < states.size(); i++) { String state = states.get(i); if (i > 0) { sb.append(" < "); } if (state.equals(currentState)) { sb.append("*"); } sb.append(state); } return sb.toString(); } } }