package org.robotninjas.barge.state;
import com.google.common.base.Throwables;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListenableFutureTask;
import org.jetlang.fibers.Fiber;
import org.robotninjas.barge.RaftException;
import org.robotninjas.barge.RaftExecutor;
import org.robotninjas.barge.api.AppendEntries;
import org.robotninjas.barge.api.AppendEntriesResponse;
import org.robotninjas.barge.api.RequestVote;
import org.robotninjas.barge.api.RequestVoteResponse;
import org.robotninjas.barge.log.RaftLog;
import org.slf4j.MDC;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
import javax.inject.Inject;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import static com.google.common.base.Preconditions.checkNotNull;
/**
 * {@link Raft} implementation that forwards protocol calls to the current {@link State} delegate
 * and performs transitions between states.
 *
 * <p>Protocol calls ({@link #requestVote}, {@link #appendEntries}, {@link #commitOperation}) and
 * the initial transition in {@link #init()} are submitted as tasks to the injected
 * {@code @RaftExecutor} fiber. {@link StateTransitionListener}s are notified of state changes,
 * invalid transitions and stops; {@link RaftProtocolListener}s are notified of protocol calls.
 */
@NotThreadSafe
class RaftStateContext implements Raft {

  private final StateFactory stateFactory;
  private final Executor executor;
  private final String name;
  private final Set<StateTransitionListener> listeners = Sets.newConcurrentHashSet();
  private final Set<RaftProtocolListener> protocolListeners = Sets.newConcurrentHashSet();

  // Current state type and its implementation; both are null until the first setState call.
  private volatile StateType state;
  private volatile State delegate;
  // Set by stop(); while true, setState transitions to STOPPED instead of the requested state.
  private boolean stop;

  /**
   * Creates a context.
   *
   * @param name identifier of this node, also placed in the SLF4J MDC under {@code "self"}
   * @param stateFactory factory used to build the {@link State} for each {@link StateType}
   * @param executor fiber on which state initialization and protocol calls are run
   * @param listeners transition listeners registered in addition to the built-in LogListener
   * @param protocolListeners listeners notified of init/requestVote/appendEntries/commit calls
   */
  @Inject
  RaftStateContext(String name, StateFactory stateFactory, @RaftExecutor Fiber executor,
      Set<StateTransitionListener> listeners, Set<RaftProtocolListener> protocolListeners) {
    MDC.put("self", name);
    this.stateFactory = stateFactory;
    this.executor = executor;
    this.name = name;
    this.listeners.add(new LogListener());
    this.listeners.addAll(listeners);
    this.protocolListeners.addAll(protocolListeners);
  }

  /** Creates a context named after {@code log.self()} with no protocol listeners. */
  RaftStateContext(RaftLog log, StateFactory stateFactory, Fiber executor,
      Set<StateTransitionListener> listeners) {
    this(log.self().toString(), stateFactory, executor, listeners);
  }

  /** Creates a context with no protocol listeners. */
  RaftStateContext(String name, StateFactory stateFactory, Fiber executor,
      Set<StateTransitionListener> listeners) {
    this(name, stateFactory, executor, listeners, Collections.<RaftProtocolListener>emptySet());
  }

  /**
   * Schedules the transition into {@link StateType#START} on the executor.
   *
   * @return a future completing with {@code START} once the transition task has run
   */
  @Override
  public ListenableFuture<StateType> init() {
    ListenableFuture<StateType> init = submit(new Callable<StateType>() {
      @Override
      public StateType call() {
        setState(null, StateType.START);
        return StateType.START;
      }
    });
    notifiesInit();
    return init;
  }

  /**
   * Runs the vote request against the current state on the executor and blocks for the result.
   * Protocol listeners are notified after the wait, whether or not it succeeded.
   */
  @Override
  @Nonnull
  public RequestVoteResponse requestVote(@Nonnull final RequestVote request) {
    checkNotNull(request);
    ListenableFuture<RequestVoteResponse> response = submit(new Callable<RequestVoteResponse>() {
      @Override
      public RequestVoteResponse call() throws Exception {
        return delegate.requestVote(RaftStateContext.this, request);
      }
    });
    try {
      return awaitResult(response);
    } finally {
      notifyRequestVote(request);
    }
  }

  /**
   * Runs the append-entries request against the current state on the executor and blocks for the
   * result. Protocol listeners are notified after the wait, whether or not it succeeded.
   */
  @Override
  @Nonnull
  public AppendEntriesResponse appendEntries(@Nonnull final AppendEntries request) {
    checkNotNull(request);
    ListenableFuture<AppendEntriesResponse> response =
        submit(new Callable<AppendEntriesResponse>() {
          @Override
          public AppendEntriesResponse call() throws Exception {
            return delegate.appendEntries(RaftStateContext.this, request);
          }
        });
    try {
      return awaitResult(response);
    } finally {
      notifyAppendEntries(request);
    }
  }

  /**
   * Schedules {@code op} to be committed by the current state on the executor. Does not block;
   * protocol listeners are notified as soon as the task has been handed to the executor.
   */
  @Override
  @Nonnull
  public ListenableFuture<Object> commitOperation(@Nonnull final byte[] op) throws RaftException {
    checkNotNull(op);
    ListenableFuture<Object> response = submit(new Callable<Object>() {
      @Override
      public Object call() throws Exception {
        return delegate.commitOperation(RaftStateContext.this, op);
      }
    });
    notifyCommit(op);
    return response;
  }

  /**
   * Replaces the current state. If {@link #stop()} has been called, the context moves to
   * {@link StateType#STOPPED} regardless of the requested state.
   *
   * @param oldState the delegate the caller believes is current; must be identical to it
   * @param state the requested next state type
   * @throws IllegalStateException if {@code oldState} is not the current delegate
   * @throws NullPointerException if {@code state} is null and no stop is pending
   */
  public synchronized void setState(State oldState, @Nonnull StateType state) {
    if (this.delegate != oldState) {
      notifiesInvalidTransition(oldState);
      throw new IllegalStateException("Invalid transition from "
          + ((oldState == null) ? null : oldState.type()) + " to " + state
          + "; current state is " + this.state);
    }
    if (stop) {
      state = StateType.STOPPED;
      notifiesStop();
    }
    // Validate before any side effect so a bad argument cannot leave a destroyed delegate installed.
    checkNotNull(state);
    if (this.delegate != null) {
      this.delegate.destroy(this);
    }
    this.state = state;
    delegate = stateFactory.makeState(state);
    MDC.put("state", this.state.toString());
    notifiesChangeState(oldState);
    delegate.init(this);
  }

  @Override
  public void addTransitionListener(@Nonnull StateTransitionListener transitionListener) {
    listeners.add(transitionListener);
  }

  @Override
  public void addRaftProtocolListener(@Nonnull RaftProtocolListener protocolListener) {
    protocolListeners.add(protocolListener);
  }

  /**
   * Returns the current state type.
   *
   * <p>NOTE(review): this is null until the first {@link #setState} call has run, despite the
   * {@code @Nonnull} annotation.
   */
  @Override
  @Nonnull
  public StateType type() {
    return state;
  }

  /**
   * Requests a stop: marks the context so the next transition goes to STOPPED, and asks the
   * current delegate (if any) to stop via {@code doStop}.
   */
  public synchronized void stop() {
    stop = true;
    if (this.delegate != null) {
      this.delegate.doStop(this);
    }
  }

  @Override
  public String toString() {
    return name;
  }

  /** Wraps {@code task} in a listenable task and hands it to the executor. */
  private <T> ListenableFuture<T> submit(Callable<T> task) {
    ListenableFutureTask<T> future = ListenableFutureTask.create(task);
    executor.execute(future);
    return future;
  }

  /**
   * Blocks until {@code future} completes and returns its value, rethrowing any failure
   * unchecked via {@link Throwables#propagate}.
   */
  private static <T> T awaitResult(ListenableFuture<T> future) {
    try {
      return future.get();
    } catch (InterruptedException e) {
      // Restore the interrupt flag so code further up the stack can still observe it.
      Thread.currentThread().interrupt();
      throw Throwables.propagate(e);
    } catch (Exception e) {
      throw Throwables.propagate(e);
    }
  }

  private void notifiesStop() {
    for (StateTransitionListener listener : listeners) {
      listener.stop(this);
    }
  }

  private void notifiesInvalidTransition(State oldState) {
    for (StateTransitionListener listener : listeners) {
      listener.invalidTransition(this, state, (oldState == null) ? null : oldState.type());
    }
  }

  private void notifiesChangeState(State oldState) {
    for (StateTransitionListener listener : listeners) {
      listener.changeState(this, (oldState == null) ? null : oldState.type(), state);
    }
  }

  private void notifiesInit() {
    for (RaftProtocolListener protocolListener : protocolListeners) {
      protocolListener.init(this);
    }
  }

  private void notifyAppendEntries(AppendEntries request) {
    for (RaftProtocolListener protocolListener : protocolListeners) {
      protocolListener.appendEntries(this, request);
    }
  }

  private void notifyRequestVote(RequestVote vote) {
    for (RaftProtocolListener protocolListener : protocolListeners) {
      protocolListener.requestVote(this, vote);
    }
  }

  private void notifyCommit(byte[] bytes) {
    for (RaftProtocolListener protocolListener : protocolListeners) {
      protocolListener.commit(this, bytes);
    }
  }
}