/*
 * Copyright 2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.statemachine.support;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.Test;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler;
import org.springframework.statemachine.StateContext;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.state.State;
import org.springframework.statemachine.support.StateMachineExecutor.StateMachineExecutorTransit;
import org.springframework.statemachine.transition.Transition;
import org.springframework.statemachine.trigger.EventTrigger;
import org.springframework.statemachine.trigger.TimerTrigger;
import org.springframework.statemachine.trigger.Trigger;

/**
 * Tests for {@link DefaultStateMachineExecutor} using mocked states, transitions
 * and state machine, with a recording {@link StateMachineExecutorTransit} to
 * observe which transitions the executor hands over.
 */
public class DefaultStateMachineExecutorTests {

	/**
	 * Queues a single {@code E1} event on a started executor and expects the
	 * transit callback to be invoked twice within two seconds (presumably once
	 * for the initial transition and once for the event transition - confirm
	 * against the executor implementation).
	 */
	@Test
	public void testSimpleExecute() throws Exception {
		Message<String> message = MessageBuilder.withPayload("E1").build();
		EventTrigger<String, String> triggerE1 = new EventTrigger<String, String>("E1");

		State<String, String> stateS1 = mockState("S1");
		State<String, String> stateS2 = mockState("S2");
		Transition<String, String> transitionS1S2 = mockTransition(stateS1, stateS2, triggerE1);

		Collection<Transition<String, String>> transitions = new ArrayList<>();
		transitions.add(transitionS1S2);
		Map<Trigger<String, String>, Transition<String, String>> triggerToTransitionMap = new HashMap<>();
		triggerToTransitionMap.put(triggerE1, transitionS1S2);

		TestStateMachineExecutorTransit transit = new TestStateMachineExecutorTransit();
		transit.reset(2);
		DefaultStateMachineExecutor<String, String> executor = newExecutor(mockStateMachine(stateS1), transitions,
				triggerToTransitionMap, transit);

		executor.start();
		executor.queueEvent(message);
		executor.execute();

		assertThat(transit.latch.await(2, TimeUnit.SECONDS), is(true));
		assertThat(transit.transitions.size(), is(2));
	}

	/**
	 * Starts and arms a {@code TimerTrigger(1000, 1)} and expects the transit
	 * callback to be invoked twice within two seconds.
	 */
	@Test
	public void testSimpleTimer() throws Exception {
		TimerTrigger<String, String> triggerTimer = new TimerTrigger<>(1000, 1);
		triggerTimer.setTaskScheduler(new ConcurrentTaskScheduler());

		TestStateMachineExecutorTransit transit = new TestStateMachineExecutorTransit();
		transit.reset(2);
		DefaultStateMachineExecutor<String, String> executor = newTimerExecutor(triggerTimer, transit);

		executor.start();
		triggerTimer.start();
		triggerTimer.arm();

		assertThat(transit.latch.await(2, TimeUnit.SECONDS), is(true));
		assertThat(transit.transitions.size(), is(2));
	}

	/**
	 * Starts a {@code TimerTrigger(1000)} without arming it and expects the
	 * transit callback to be invoked twice within two seconds.
	 */
	@Test
	public void testDeadlock() throws Exception {
		// gh-315
		// nasty, with deadlock you can't use junit timeout
		// as then test is run on different thread, thus test doesn't fail.
		TimerTrigger<String, String> triggerTimer = new TimerTrigger<>(1000);
		triggerTimer.setTaskScheduler(new ConcurrentTaskScheduler());

		TestStateMachineExecutorTransit transit = new TestStateMachineExecutorTransit();
		transit.reset(2);
		DefaultStateMachineExecutor<String, String> executor = newTimerExecutor(triggerTimer, transit);

		executor.start();
		triggerTimer.start();

		assertThat(transit.latch.await(2, TimeUnit.SECONDS), is(true));
		assertThat(transit.transitions.size(), is(2));
	}

	/**
	 * Creates a mocked state whose {@code getId()} returns {@code id} and whose
	 * {@code getIds()} returns a single-element list holding {@code id}.
	 *
	 * <p>Centralising the stubbing guarantees that every mock is stubbed on
	 * itself; the previous inline setup stubbed {@code stateS1} three times and
	 * left {@code stateS2}/{@code stateS3} unstubbed in the timer tests.
	 */
	@SuppressWarnings("unchecked")
	private static State<String, String> mockState(String id) {
		State<String, String> state = mock(State.class);
		when(state.getId()).thenReturn(id);
		when(state.getIds()).thenReturn(Arrays.asList(id));
		return state;
	}

	/**
	 * Creates a mocked transition reporting the given source, target and
	 * trigger, and whose {@code transit(..)} always returns {@code true}.
	 */
	@SuppressWarnings("unchecked")
	private static Transition<String, String> mockTransition(State<String, String> source,
			State<String, String> target, Trigger<String, String> trigger) {
		Transition<String, String> transition = mock(Transition.class);
		when(transition.getSource()).thenReturn(source);
		when(transition.getTarget()).thenReturn(target);
		when(transition.getTrigger()).thenReturn(trigger);
		when(transition.transit(any())).thenReturn(true);
		return transition;
	}

	/**
	 * Creates a mocked state machine whose {@code getState()} returns
	 * {@code currentState}.
	 */
	@SuppressWarnings("unchecked")
	private static StateMachine<String, String> mockStateMachine(State<String, String> currentState) {
		StateMachine<String, String> stateMachine = mock(StateMachine.class);
		when(stateMachine.getState()).thenReturn(currentState);
		return stateMachine;
	}

	/**
	 * Builds an executor (not yet started) around the given mocks. The executor
	 * gets no triggerless transitions, an unstubbed mock as initial transition,
	 * a {@code null} initial event, a {@link SyncTaskExecutor} and {@code transit}
	 * as its transit callback.
	 */
	@SuppressWarnings("unchecked")
	private static DefaultStateMachineExecutor<String, String> newExecutor(StateMachine<String, String> stateMachine,
			Collection<Transition<String, String>> transitions,
			Map<Trigger<String, String>, Transition<String, String>> triggerToTransitionMap,
			TestStateMachineExecutorTransit transit) {
		List<Transition<String, String>> triggerlessTransitions = new ArrayList<>();
		Transition<String, String> initialTransition = mock(Transition.class);
		Message<String> initialEvent = null;

		// The same mock serves as both state machine arguments of the constructor.
		DefaultStateMachineExecutor<String, String> executor = new DefaultStateMachineExecutor<>(
				stateMachine, stateMachine, transitions, triggerToTransitionMap, triggerlessTransitions,
				initialTransition, initialEvent);
		executor.setTaskExecutor(new SyncTaskExecutor());
		executor.setStateMachineExecutorTransit(transit);
		return executor;
	}

	/**
	 * Builds the fixture shared by the timer tests: states S1, S2 and S3 with S1
	 * as the machine's current state, an {@code E1}-triggered transition S1->S2
	 * and a transition S1->S3 bound to {@code triggerTimer}.
	 */
	private static DefaultStateMachineExecutor<String, String> newTimerExecutor(
			TimerTrigger<String, String> triggerTimer, TestStateMachineExecutorTransit transit) {
		EventTrigger<String, String> triggerE1 = new EventTrigger<String, String>("E1");

		State<String, String> stateS1 = mockState("S1");
		State<String, String> stateS2 = mockState("S2");
		State<String, String> stateS3 = mockState("S3");

		Transition<String, String> transitionS1S2 = mockTransition(stateS1, stateS2, triggerE1);
		Transition<String, String> transitionS1S3 = mockTransition(stateS1, stateS3, triggerTimer);

		// Only the event transition goes into the plain transition collection;
		// the timer transition is registered solely via the trigger map.
		Collection<Transition<String, String>> transitions = new ArrayList<>();
		transitions.add(transitionS1S2);
		Map<Trigger<String, String>, Transition<String, String>> triggerToTransitionMap = new HashMap<>();
		triggerToTransitionMap.put(triggerE1, transitionS1S2);
		triggerToTransitionMap.put(triggerTimer, transitionS1S3);

		return newExecutor(mockStateMachine(stateS1), transitions, triggerToTransitionMap, transit);
	}

	/**
	 * Transit callback that records every transition it receives and counts
	 * down a latch per invocation so tests can wait for an expected number of
	 * calls.
	 */
	private static class TestStateMachineExecutorTransit implements StateMachineExecutorTransit<String, String> {

		final List<Transition<String, String>> transitions = new ArrayList<>();
		CountDownLatch latch = new CountDownLatch(1);

		@Override
		public void transit(Transition<String, String> transition, StateContext<String, String> stateContext,
				Message<String> message) {
			transitions.add(transition);
			latch.countDown();
		}

		/**
		 * Replaces the latch with one expecting {@code i} invocations and
		 * discards previously recorded transitions.
		 */
		void reset(int i) {
			latch = new CountDownLatch(i);
			transitions.clear();
		}
	}
}