/*
* Copyright 2011-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amazonaws.services.stepfunctions.builder.internal.validation;
import com.amazonaws.services.stepfunctions.builder.ErrorCodes;
import com.amazonaws.services.stepfunctions.builder.StateMachine;
import com.amazonaws.services.stepfunctions.builder.conditions.BinaryCondition;
import com.amazonaws.services.stepfunctions.builder.conditions.Condition;
import com.amazonaws.services.stepfunctions.builder.conditions.NAryCondition;
import com.amazonaws.services.stepfunctions.builder.conditions.NotCondition;
import com.amazonaws.services.stepfunctions.builder.internal.PropertyNames;
import com.amazonaws.services.stepfunctions.builder.states.Branch;
import com.amazonaws.services.stepfunctions.builder.states.Catcher;
import com.amazonaws.services.stepfunctions.builder.states.Choice;
import com.amazonaws.services.stepfunctions.builder.states.ChoiceState;
import com.amazonaws.services.stepfunctions.builder.states.FailState;
import com.amazonaws.services.stepfunctions.builder.states.NextStateTransition;
import com.amazonaws.services.stepfunctions.builder.states.ParallelState;
import com.amazonaws.services.stepfunctions.builder.states.PassState;
import com.amazonaws.services.stepfunctions.builder.states.Retrier;
import com.amazonaws.services.stepfunctions.builder.states.State;
import com.amazonaws.services.stepfunctions.builder.states.StateVisitor;
import com.amazonaws.services.stepfunctions.builder.states.SucceedState;
import com.amazonaws.services.stepfunctions.builder.states.TaskState;
import com.amazonaws.services.stepfunctions.builder.states.Transition;
import com.amazonaws.services.stepfunctions.builder.states.TransitionState;
import com.amazonaws.services.stepfunctions.builder.states.WaitFor;
import com.amazonaws.services.stepfunctions.builder.states.WaitForSeconds;
import com.amazonaws.services.stepfunctions.builder.states.WaitForSecondsPath;
import com.amazonaws.services.stepfunctions.builder.states.WaitForTimestamp;
import com.amazonaws.services.stepfunctions.builder.states.WaitForTimestampPath;
import com.amazonaws.services.stepfunctions.builder.states.WaitState;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Validator for a {@link StateMachine} object.
* // TODO Does not check max nesting.
* // TODO Does not validate ARNs against a regex
*/
public class StateMachineValidator {
private final ProblemReporter problemReporter = new ProblemReporter();
private final StateMachine stateMachine;
/**
 * Creates a validator bound to the given state machine. Nothing is checked here; all validation
 * happens in {@link #validate()}.
 *
 * @param stateMachine State machine to validate.
 */
public StateMachineValidator(StateMachine stateMachine) {
this.stateMachine = stateMachine;
}
/**
 * Validates the state machine in two phases: first the top-level properties and every state,
 * then, only if that first phase reported nothing, the shape of the state graph.
 *
 * @return The state machine given to the constructor, when no problems were reported.
 * @throws RuntimeException The exception supplied by the problem reporter when at least one problem
 *                          was reported (NOTE(review): exact exception type is defined by
 *                          ProblemReporter#getException, not visible here).
 */
public StateMachine validate() {
    final ValidationContext rootContext = ValidationContext.builder()
            .problemReporter(problemReporter)
            .parentContext(null)
            .identifier("Root")
            .location(Location.StateMachine)
            .build();

    // Phase one: top-level properties, then every state and its nested properties.
    final String startAt = stateMachine.getStartAt();
    rootContext.assertStringNotEmpty(startAt, PropertyNames.START_AT);
    rootContext.assertIsPositiveIfPresent(stateMachine.getTimeoutSeconds(), PropertyNames.TIMEOUT_SECONDS);

    final Map<String, State> allStates = stateMachine.getStates();
    rootContext.assertNotEmpty(allStates, PropertyNames.STATES);
    validateStates(rootContext, allStates);

    final boolean startStateExists = allStates.containsKey(startAt);
    if (!startStateExists) {
        final String message = String.format("%s state does not exist.", PropertyNames.START_AT);
        problemReporter.report(new Problem(rootContext, message));
    }

    // Phase two: the graph walk assumes every referenced state exists, so it is skipped
    // whenever the basic checks above found anything wrong.
    final boolean basicValidationPassed = !problemReporter.hasProblems();
    if (basicValidationPassed) {
        new GraphValidator(rootContext, stateMachine).validate();
    }

    if (problemReporter.hasProblems()) {
        throw problemReporter.getException();
    }
    return stateMachine;
}
/**
 * Walks the state graph from a starting state and reports two kinds of structural problems:
 * unrecoverable cycles (cycles with no branching logic) and, for a walk that has no parent,
 * the absence of any path to a terminal state.
 *
 * <p>Each choice of a {@link ChoiceState} and each branch of a {@link ParallelState} is explored
 * by a fresh nested instance. Choice walks inherit the states already seen on the way to the
 * choice; branch walks start with nothing inherited because a branch has its own state map.
 */
private final class GraphValidator {
    /** States seen by the walks that led to this one; revisiting one of them is not reported as a cycle. */
    private final Map<String, State> parentVisited;
    /** Name of the state this walk starts from. */
    private final String initialState;
    /** State map the names in this walk are resolved against. */
    private final Map<String, State> states;
    /** States seen by this walk itself; revisiting one of them is reported as a cycle. */
    private final Map<String, State> visited = new HashMap<String, State>();
    /** Context under which problems found by this walk are reported. */
    private final ValidationContext currentContext;

    /**
     * Creates the top-level walk, starting at the state machine's StartAt state with no
     * inherited visited states.
     */
    public GraphValidator(ValidationContext context, StateMachine stateMachine) {
        this(context,
             Collections.<String, State>emptyMap(),
             stateMachine.getStartAt(),
             stateMachine.getStates());
    }

    private GraphValidator(ValidationContext context,
                           Map<String, State> parentVisited,
                           String initialState,
                           Map<String, State> states) {
        this.currentContext = context;
        this.parentVisited = parentVisited;
        this.initialState = initialState;
        this.states = states;
    }

    /**
     * Runs the walk from the initial state.
     *
     * @return True if a terminal state is reachable from the initial state.
     */
    public boolean validate() {
        final boolean reachesTerminal = visit(initialState);
        // Only a walk with nothing inherited (the root walk or a parallel branch) reports a missing
        // terminal path; a choice walk just hands its result back to the walk that spawned it.
        final boolean hasNoParent = parentVisited.isEmpty();
        if (hasNoParent && !reachesTerminal) {
            problemReporter.report(new Problem(currentContext, "No path to a terminal state exists."));
        }
        return reachesTerminal;
    }

    /**
     * Follows transitions starting at the given state until a terminal state, a choice state or an
     * already visited state is reached.
     *
     * @return True if a terminal state is reachable from the given state.
     */
    private boolean visit(String stateName) {
        String currentName = stateName;
        while (true) {
            final ValidationContext stateContext = currentContext.state(currentName);
            final State state = states.get(currentName);

            if (parentVisited.containsKey(currentName)) {
                // Cycle but to parent so we may be okay
                return false;
            }
            if (visited.containsKey(currentName)) {
                problemReporter.report(new Problem(stateContext, "Cycle detected."));
                return false;
            }
            visited.put(currentName, state);

            if (state instanceof ParallelState) {
                validateParallelState(stateContext, (ParallelState) state);
            }

            if (state.isTerminalState()) {
                return true;
            }
            if (state instanceof TransitionState) {
                // Non-terminal transition state: keep walking from the state it points to.
                final Transition transition = ((TransitionState) state).getTransition();
                currentName = ((NextStateTransition) transition).getNextStateName();
                continue;
            }
            if (state instanceof ChoiceState) {
                return validateChoiceState(stateContext, (ChoiceState) state);
            }
            throw new RuntimeException("Unexpected state type: " + state.getClass().getName());
        }
    }

    /**
     * Walks every branch of a parallel state as an independent graph with its own start state
     * and state map.
     */
    private void validateParallelState(ValidationContext stateContext, ParallelState state) {
        int branchIndex = 0;
        for (Branch branch : state.getBranches()) {
            final GraphValidator branchWalk = new GraphValidator(stateContext.branch(branchIndex),
                                                                 Collections.<String, State>emptyMap(),
                                                                 branch.getStartAt(),
                                                                 branch.getStates());
            branchWalk.validate();
            branchIndex++;
        }
    }

    /**
     * Walks the default target (if any) and then every choice target of a choice state.
     *
     * @return True if at least one of those targets can reach a terminal state.
     */
    private boolean validateChoiceState(ValidationContext stateContext, ChoiceState choiceState) {
        final Map<String, State> seenSoFar = mergeParentVisited();
        boolean reachesTerminal = false;

        final String defaultStateName = choiceState.getDefaultStateName();
        if (defaultStateName != null) {
            reachesTerminal = new GraphValidator(stateContext, seenSoFar, defaultStateName, states).validate();
        }

        int choiceIndex = 0;
        for (Choice choice : choiceState.getChoices()) {
            final String nextStateName = ((NextStateTransition) choice.getTransition()).getNextStateName();
            // Every choice is walked unconditionally, even once a terminal path is already known,
            // so that problems in later choices are still reported.
            final boolean choiceReachesTerminal =
                    new GraphValidator(stateContext.choice(choiceIndex), seenSoFar, nextStateName, states).validate();
            if (choiceReachesTerminal) {
                reachesTerminal = true;
            }
            choiceIndex++;
        }
        return reachesTerminal;
    }

    /**
     * @return A new map holding the inherited visited states plus the ones visited by this walk.
     */
    private Map<String, State> mergeParentVisited() {
        final Map<String, State> combined = new HashMap<String, State>(parentVisited);
        combined.putAll(visited);
        return combined;
    }
}
/**
 * Validates every state in the given map: its name, then the state itself through a
 * {@link StateValidationVisitor} scoped to that state.
 *
 * @param parentContext Context the per-state contexts are derived from.
 * @param states        States to validate, keyed by state name. Also handed to the visitor so it can
 *                      check that transitions point at states in this same map.
 */
private void validateStates(ValidationContext parentContext, Map<String, State> states) {
    for (Map.Entry<String, State> namedState : states.entrySet()) {
        final String stateName = namedState.getKey();
        parentContext.assertStringNotEmpty(stateName, "State Name");

        final State state = namedState.getValue();
        final ValidationContext stateContext = parentContext.state(stateName);
        state.accept(new StateValidationVisitor(states, stateContext));
    }
}
/**
 * Validates all the supported states and their nested properties.
 *
 * <p>One instance validates one state. Problems are reported against the context given at
 * construction (or a nested choice/branch/retrier/catcher context derived from it), and transition
 * targets are resolved against the state map given at construction.
 */
private class StateValidationVisitor extends StateVisitor<Void> {
    /** Context of the state being validated. */
    private final ValidationContext currentContext;
    /** State map that transitions of the validated state must point into. */
    private final Map<String, State> states;

    private StateValidationVisitor(Map<String, State> states, ValidationContext context) {
        this.states = states;
        this.currentContext = context;
    }

    /**
     * Validates a choice state: its paths, the optional default state, and the transition and
     * condition of every choice.
     */
    @Override
    public Void visit(ChoiceState choiceState) {
        currentContext.assertIsValidInputPath(choiceState.getInputPath());
        currentContext.assertIsValidOutputPath(choiceState.getOutputPath());
        // The default state is optional, but when given it must name an existing state.
        if (choiceState.getDefaultStateName() != null) {
            currentContext.assertStringNotEmpty(choiceState.getDefaultStateName(), PropertyNames.DEFAULT_STATE);
            assertContainsState(choiceState.getDefaultStateName());
        }
        currentContext.assertNotEmpty(choiceState.getChoices(), PropertyNames.CHOICES);
        int index = 0;
        for (Choice choice : choiceState.getChoices()) {
            ValidationContext choiceContext = currentContext.choice(index);
            validateTransition(choiceContext, choice.getTransition());
            validateCondition(choiceContext, choice.getCondition());
            index++;
        }
        return null;
    }

    /**
     * Validates a condition, recursing into the operands of NOT and n-ary conditions.
     *
     * @throws RuntimeException If the condition is non-null and of an unsupported type.
     */
    private void validateCondition(ValidationContext context, Condition condition) {
        context.assertNotNull(condition, "Condition");
        if (condition instanceof BinaryCondition) {
            validateBinaryCondition(context, (BinaryCondition) condition);
        } else if (condition instanceof NAryCondition) {
            validateNAryCondition(context, (NAryCondition) condition);
        } else if (condition instanceof NotCondition) {
            validateCondition(context, ((NotCondition) condition).getCondition());
        } else if (condition != null) {
            // A null condition is handled by the assertNotNull above; anything else is unknown.
            throw new RuntimeException("Unsupported condition type: " + condition.getClass());
        }
    }

    /**
     * Validates that an n-ary condition has operands and that each operand is itself valid.
     */
    private void validateNAryCondition(ValidationContext context, NAryCondition condition) {
        context.assertNotEmpty(condition.getConditions(), "Conditions");
        for (Condition nestedCondition : condition.getConditions()) {
            validateCondition(context, nestedCondition);
        }
    }

    /**
     * Validates that a binary condition has a non-empty, valid JSON path variable and an
     * expected value.
     */
    private void validateBinaryCondition(ValidationContext context, BinaryCondition condition) {
        context.assertStringNotEmpty(condition.getVariable(), PropertyNames.VARIABLE);
        context.assertIsValidJsonPath(condition.getVariable(), PropertyNames.VARIABLE);
        context.assertNotNull(condition.getExpectedValue(), "ExpectedValue");
    }

    /**
     * Validates that a fail state has a non-empty cause.
     */
    @Override
    public Void visit(FailState failState) {
        currentContext.assertStringNotEmpty(failState.getCause(), PropertyNames.CAUSE);
        return null;
    }

    /**
     * Validates a parallel state: its paths, transition, retriers, catchers and branches.
     */
    @Override
    public Void visit(ParallelState parallelState) {
        currentContext.assertIsValidInputPath(parallelState.getInputPath());
        currentContext.assertIsValidOutputPath(parallelState.getOutputPath());
        currentContext.assertIsValidResultPath(parallelState.getResultPath());
        validateTransition(parallelState.getTransition());
        validateRetriers(parallelState.getRetriers());
        validateCatchers(parallelState.getCatchers());
        validateBranches(parallelState);
        return null;
    }

    /**
     * Validates every branch of a parallel state: all of its states, and that its start state
     * exists within the branch's own state map.
     */
    private void validateBranches(ParallelState parallelState) {
        currentContext.assertNotEmpty(parallelState.getBranches(), PropertyNames.BRANCHES);
        int index = 0;
        for (Branch branch : parallelState.getBranches()) {
            ValidationContext branchContext = currentContext.branch(index);
            validateStates(branchContext, branch.getStates());
            if (!branch.getStates().containsKey(branch.getStartAt())) {
                problemReporter.report(new Problem(branchContext, String.format("%s references a non existent state.",
                                                                                PropertyNames.START_AT)));
            }
            index++;
        }
    }

    /**
     * Validates a pass state: its paths and transition.
     */
    @Override
    public Void visit(PassState passState) {
        currentContext.assertIsValidInputPath(passState.getInputPath());
        currentContext.assertIsValidOutputPath(passState.getOutputPath());
        currentContext.assertIsValidResultPath(passState.getResultPath());
        validateTransition(passState.getTransition());
        return null;
    }

    /**
     * Validates a succeed state: its input and output paths.
     */
    @Override
    public Void visit(SucceedState succeedState) {
        currentContext.assertIsValidInputPath(succeedState.getInputPath());
        currentContext.assertIsValidOutputPath(succeedState.getOutputPath());
        return null;
    }

    /**
     * Validates a task state: its paths, timeout and heartbeat, resource, retriers, catchers
     * and transition.
     */
    @Override
    public Void visit(TaskState taskState) {
        currentContext.assertIsValidInputPath(taskState.getInputPath());
        currentContext.assertIsValidOutputPath(taskState.getOutputPath());
        currentContext.assertIsValidResultPath(taskState.getResultPath());
        currentContext.assertIsPositiveIfPresent(taskState.getTimeoutSeconds(), PropertyNames.TIMEOUT_SECONDS);
        currentContext.assertIsPositiveIfPresent(taskState.getHeartbeatSeconds(), PropertyNames.HEARTBEAT_SECONDS);
        // The heartbeat is only compared with the timeout when both are set.
        if (taskState.getTimeoutSeconds() != null && taskState.getHeartbeatSeconds() != null) {
            if (taskState.getHeartbeatSeconds() >= taskState.getTimeoutSeconds()) {
                problemReporter.report(new Problem(currentContext, String.format("%s must be smaller than %s",
                                                                                 PropertyNames.HEARTBEAT_SECONDS,
                                                                                 PropertyNames.TIMEOUT_SECONDS)));
            }
        }
        currentContext.assertStringNotEmpty(taskState.getResource(), PropertyNames.RESOURCE);
        validateRetriers(taskState.getRetriers());
        validateCatchers(taskState.getCatchers());
        validateTransition(taskState.getTransition());
        return null;
    }

    /**
     * Validates each retrier's numeric settings and error codes, and reports any retrier that
     * directly follows one using {@link ErrorCodes#ALL}.
     */
    private void validateRetriers(List<Retrier> retriers) {
        boolean hasRetryAll = false;
        int index = 0;
        for (Retrier retrier : retriers) {
            ValidationContext retrierContext = currentContext.retrier(index);
            // hasRetryAll describes the previous retrier: if it used the catch-all code, this
            // retrier should not exist.
            if (hasRetryAll) {
                problemReporter.report(
                        new Problem(retrierContext,
                                    String.format("When %s is used it must be in the last Retrier", ErrorCodes.ALL)));
            }
            // MaxAttempts may be zero
            retrierContext.assertIsNotNegativeIfPresent(retrier.getMaxAttempts(), PropertyNames.MAX_ATTEMPTS);
            retrierContext.assertIsPositiveIfPresent(retrier.getIntervalSeconds(), PropertyNames.INTERVAL_SECONDS);
            if (retrier.getBackoffRate() != null && retrier.getBackoffRate() < 1.0) {
                problemReporter.report(new Problem(retrierContext, String.format("%s must be greater than or equal to 1.0",
                                                                                 PropertyNames.BACKOFF_RATE)));
            }
            hasRetryAll = validateErrorEquals(retrierContext, retrier.getErrorEquals());
            index++;
        }
    }

    /**
     * Validates each catcher's result path, transition and error codes, and reports any catcher
     * that directly follows one using {@link ErrorCodes#ALL}.
     */
    private void validateCatchers(List<Catcher> catchers) {
        boolean hasCatchAll = false;
        int index = 0;
        for (Catcher catcher : catchers) {
            ValidationContext catcherContext = currentContext.catcher(index);
            catcherContext.assertIsValidResultPath(catcher.getResultPath());
            // hasCatchAll describes the previous catcher: if it used the catch-all code, this
            // catcher should not exist.
            if (hasCatchAll) {
                problemReporter.report(
                        new Problem(catcherContext,
                                    String.format("When %s is used it must be in the last Catcher", ErrorCodes.ALL)));
            }
            validateTransition(catcherContext, catcher.getTransition());
            hasCatchAll = validateErrorEquals(catcherContext, catcher.getErrorEquals());
            index++;
        }
    }

    /**
     * Validates the error codes of a retrier or catcher.
     *
     * @param currentContext Context of the retrier or catcher being validated.
     * @param errorEquals    Error codes to validate.
     * @return True if the codes include {@link ErrorCodes#ALL}; false otherwise, including when
     *         the list is null.
     */
    private boolean validateErrorEquals(ValidationContext currentContext, List<String> errorEquals) {
        currentContext.assertNotEmpty(errorEquals, PropertyNames.ERROR_EQUALS);
        // Guard against a null list so validation does not die with a NullPointerException here;
        // the emptiness check above is the place where a missing list is dealt with.
        if (errorEquals != null && errorEquals.contains(ErrorCodes.ALL)) {
            if (errorEquals.size() != 1) {
                problemReporter.report(new Problem(currentContext, String.format(
                        "When %s is used in %s, it must be the only error code in the array",
                        ErrorCodes.ALL, PropertyNames.ERROR_EQUALS)));
            }
            return true;
        }
        return false;
    }

    /**
     * Validates a wait state: its paths, transition and wait strategy.
     */
    @Override
    public Void visit(WaitState waitState) {
        currentContext.assertIsValidInputPath(waitState.getInputPath());
        currentContext.assertIsValidOutputPath(waitState.getOutputPath());
        validateTransition(waitState.getTransition());
        validateWaitFor(waitState.getWaitFor());
        return null;
    }

    /**
     * Validates the wait strategy according to its concrete type.
     *
     * @throws RuntimeException If the strategy is non-null and of an unsupported type.
     */
    private void validateWaitFor(WaitFor waitFor) {
        currentContext.assertNotNull(waitFor, "WaitFor");
        if (waitFor instanceof WaitForSeconds) {
            currentContext.assertIsPositiveIfPresent(((WaitForSeconds) waitFor).getSeconds(), PropertyNames.SECONDS);
        } else if (waitFor instanceof WaitForSecondsPath) {
            assertWaitForPath(((WaitForSecondsPath) waitFor).getSecondsPath(), PropertyNames.SECONDS_PATH);
        } else if (waitFor instanceof WaitForTimestamp) {
            currentContext.assertNotNull(((WaitForTimestamp) waitFor).getTimestamp(), PropertyNames.TIMESTAMP);
        } else if (waitFor instanceof WaitForTimestampPath) {
            assertWaitForPath(((WaitForTimestampPath) waitFor).getTimestampPath(), PropertyNames.TIMESTAMP_PATH);
        } else if (waitFor != null) {
            // A null strategy is handled by the assertNotNull above; anything else is unknown.
            throw new RuntimeException("Unsupported WaitFor strategy: " + waitFor.getClass());
        }
    }

    /**
     * TimestampPath and SecondsPath must have a valid reference path.
     */
    private void assertWaitForPath(String pathValue, String propertyName) {
        currentContext.assertNotNull(pathValue, propertyName);
        currentContext.assertIsValidReferencePath(pathValue, propertyName);
    }

    /**
     * Validates a transition against the context of the state being visited.
     */
    private void validateTransition(Transition transition) {
        validateTransition(currentContext, transition);
    }

    /**
     * Validates that a transition is present and, when it points to a next state, that the next
     * state is named and exists.
     *
     * @param context    Context problems are reported against.
     * @param transition Transition to validate.
     */
    private void validateTransition(ValidationContext context, Transition transition) {
        context.assertNotNull(transition, "Transition");
        if (transition instanceof NextStateTransition) {
            final String nextStateName = ((NextStateTransition) transition).getNextStateName();
            context.assertNotNull(nextStateName, PropertyNames.NEXT);
            // A missing name is the concern of the null check above. Looking it up as well would
            // only add a redundant "null is not a valid state" problem, or fail outright on a map
            // implementation that rejects null keys.
            if (nextStateName != null) {
                assertContainsState(context, nextStateName);
            }
        }
    }

    /**
     * Reports a problem against the visited state's context if the named state does not exist.
     */
    private void assertContainsState(String nextStateName) {
        assertContainsState(currentContext, nextStateName);
    }

    /**
     * Reports a problem against the given context if the named state is not in the state map
     * this visitor was created with.
     */
    private void assertContainsState(ValidationContext context, String nextStateName) {
        if (!states.containsKey(nextStateName)) {
            problemReporter.report(new Problem(context, String.format("%s is not a valid state", nextStateName)));
        }
    }
}
}