package netflix.ocelli.util;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Observable.OnSubscribe;
import rx.Subscriber;
import rx.functions.Action1;
import rx.functions.Action2;
import rx.functions.Func0;
import rx.functions.Func1;
import rx.subjects.PublishSubject;
/**
 * Event-driven state machine built on RxJava.
 *
 * <p>The machine holds a context object of type {@code T} and a current
 * {@link State}. Events of type {@code E} are pushed in through
 * {@link #call(Object)} and are processed only while there is at least one
 * active subscription to the observable returned by {@link #start()}.
 *
 * <p>The current state is stored in a {@code volatile} field, but event
 * processing itself is not synchronized by this class.
 *
 * @param <T> type of the context object handed to state callbacks
 * @param <E> type of the events that drive transitions
 */
public class StateMachine<T, E> implements Action1<E> {
    private static final Logger LOG = LoggerFactory.getLogger(StateMachine.class);

    /**
     * A named state with optional enter/exit callbacks, a table of
     * event-to-state transitions, and a set of events to silently ignore.
     *
     * @param <T> type of the context object handed to the callbacks
     * @param <E> type of the events that drive transitions
     */
    public static class State<T, E> {
        private final String name;
        private Func1<T, Observable<E>> enter;
        private Func1<T, Observable<E>> exit;
        private final Map<E, State<T, E>> transitions = new HashMap<E, State<T, E>>();
        private final Set<E> ignore = new HashSet<E>();

        /**
         * Static factory equivalent to {@code new State<T, E>(name)}.
         *
         * @param name label returned by {@link #toString()}
         * @return a new state with no callbacks, transitions or ignored events
         */
        public static <T, E> State<T, E> create(String name) {
            return new State<T, E>(name);
        }

        /**
         * @param name label returned by {@link #toString()}
         */
        public State(String name) {
            this.name = name;
        }

        /**
         * Sets the callback invoked when this state is entered. Events emitted
         * by the returned observable are fed back into the state machine.
         *
         * @param func callback receiving the machine's context
         * @return this state, for chaining
         */
        public State<T, E> onEnter(Func1<T, Observable<E>> func) {
            this.enter = func;
            return this;
        }

        /**
         * Sets the callback invoked when this state is left.
         *
         * @param func callback receiving the machine's context
         * @return this state, for chaining
         */
        public State<T, E> onExit(Func1<T, Observable<E>> func) {
            this.exit = func;
            return this;
        }

        /**
         * Registers a transition: when {@code event} arrives while this state
         * is current, the machine moves to {@code state}. A later registration
         * for the same event replaces the earlier one.
         *
         * @param event triggering event
         * @param state target state
         * @return this state, for chaining
         */
        public State<T, E> transition(E event, State<T, E> state) {
            transitions.put(event, state);
            return this;
        }

        /**
         * Marks {@code event} as expected-but-ignored in this state, which
         * suppresses the "unexpected event" warning.
         *
         * @param event event to ignore
         * @return this state, for chaining
         */
        public State<T, E> ignore(E event) {
            ignore.add(event);
            return this;
        }

        /**
         * Invokes the enter callback, if one was set.
         *
         * @param context the machine's context
         * @return the callback's observable, or an empty observable when no
         *         callback is registered
         */
        Observable<E> enter(T context) {
            if (enter != null) {
                return enter.call(context);
            }
            return Observable.empty();
        }

        /**
         * Invokes the exit callback, if one was set.
         *
         * @param context the machine's context
         * @return the callback's observable, or an empty observable when no
         *         callback is registered
         */
        Observable<E> exit(T context) {
            if (exit != null) {
                // Hand the callback's result back to the caller instead of
                // dropping it, mirroring enter().
                return exit.call(context);
            }
            return Observable.empty();
        }

        /**
         * Looks up the transition target for {@code event}.
         *
         * @param event incoming event
         * @return the target state, or {@code null} when no transition is
         *         registered for the event
         */
        State<T, E> next(E event) {
            return transitions.get(event);
        }

        /** @return the name given at construction */
        @Override
        public String toString() {
            return name;
        }
    }

    private volatile State<T, E> state;
    private final T context;
    private final PublishSubject<E> events = PublishSubject.create();

    /**
     * Static factory equivalent to {@code new StateMachine<T, E>(context, initial)}.
     *
     * @param context object handed to every state callback
     * @param initial state the machine starts in
     * @return a new, not yet started, state machine
     */
    public static <T, E> StateMachine<T, E> create(T context, State<T, E> initial) {
        return new StateMachine<T, E>(context, initial);
    }

    /**
     * @param context object handed to every state callback
     * @param initial state the machine starts in
     */
    public StateMachine(T context, State<T, E> initial) {
        this.state = initial;
        this.context = context;
    }

    /**
     * Returns an observable that runs the state machine for as long as it is
     * subscribed to. The returned observable itself never calls
     * {@code onNext}, {@code onCompleted} or {@code onError} on its subscriber.
     *
     * <p>On each subscription the machine starts listening for events and
     * then enters whatever state is current at that moment; events emitted by
     * that state's enter observable are fed back into the machine.
     * Unsubscribing stops event processing.
     *
     * @return observable whose subscription lifetime bounds event processing
     */
    public Observable<Void> start() {
        return Observable.create(new OnSubscribe<Void>() {
            @Override
            public void call(Subscriber<? super Void> sub) {
                // Listen for events first so that anything emitted
                // synchronously by the initial enter callback is processed.
                sub.add(events.collect(new Func0<T>() {
                    @Override
                    public T call() {
                        return context;
                    }
                }, new Action2<T, E>() {
                    @Override
                    public void call(T context, E event) {
                        // Snapshot the volatile field once so the lookup, the
                        // exit call and the ignore check all see the same state.
                        final State<T, E> current = state;
                        LOG.trace("{} : {}({})", context, current, event);
                        final State<T, E> next = current.next(event);
                        if (next != null) {
                            // The exit callback is invoked for its side effects;
                            // its observable is not subscribed to here.
                            current.exit(context);
                            state = next;
                            next.enter(context).subscribe(StateMachine.this);
                        } else if (!current.ignore.contains(event)) {
                            LOG.warn("Unexpected event {} in state {} for {} ", event, current, context);
                        }
                    }
                })
                .subscribe());

                // Subscribe to the initial state's enter observable so its
                // events reach the machine, as they do for every later state,
                // and tie that subscription to the subscriber's lifetime.
                sub.add(state.enter(context).subscribe(StateMachine.this));
            }
        });
    }

    /**
     * Pushes an event into the machine. The event is dropped unless there is
     * an active subscription to {@link #start()}.
     *
     * @param event event to process
     */
    @Override
    public void call(E event) {
        events.onNext(event);
    }

    /** @return the current state */
    public State<T, E> getState() {
        return state;
    }

    /** @return the context object supplied at construction */
    public T getContext() {
        return context;
    }
}