package edu.washington.escience.myria.daemon;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import javax.inject.Inject;
import org.apache.reef.driver.client.JobMessageObserver;
import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.context.ContextConfiguration;
import org.apache.reef.driver.context.FailedContext;
import org.apache.reef.driver.evaluator.AllocatedEvaluator;
import org.apache.reef.driver.evaluator.CompletedEvaluator;
import org.apache.reef.driver.evaluator.EvaluatorRequest;
import org.apache.reef.driver.evaluator.EvaluatorRequestor;
import org.apache.reef.driver.evaluator.FailedEvaluator;
import org.apache.reef.driver.evaluator.JVMProcess;
import org.apache.reef.driver.evaluator.JVMProcessFactory;
import org.apache.reef.driver.task.CompletedTask;
import org.apache.reef.driver.task.FailedTask;
import org.apache.reef.driver.task.RunningTask;
import org.apache.reef.driver.task.TaskConfiguration;
import org.apache.reef.driver.task.TaskMessage;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.Configurations;
import org.apache.reef.tang.Injector;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.tang.exceptions.BindException;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.tang.formats.AvroConfigurationSerializer;
import org.apache.reef.tang.formats.ConfigurationSerializer;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.remote.address.LocalAddressProvider;
import org.apache.reef.wake.time.event.StartTime;
import org.apache.reef.wake.time.event.StopTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Striped;
import com.google.protobuf.InvalidProtocolBufferException;
import edu.washington.escience.myria.MyriaConstants;
import edu.washington.escience.myria.parallel.Server;
import edu.washington.escience.myria.parallel.SocketInfo;
import edu.washington.escience.myria.parallel.Worker;
import edu.washington.escience.myria.proto.ControlProto.ControlMessage;
import edu.washington.escience.myria.proto.TransportProto.TransportMessage;
import edu.washington.escience.myria.tools.MyriaGlobalConfigurationModule;
import edu.washington.escience.myria.tools.MyriaWorkerConfigurationModule;
import edu.washington.escience.myria.util.IPCUtils;
import edu.washington.escience.myria.util.concurrent.RenamingThreadFactory;
/**
* Driver for Myria API server/master. Each worker (including master) is mapped to an Evaluator and a persistent Task.
*/
@Unit
public final class MyriaDriver {
/** Class logger. */
private static final Logger LOGGER = LoggerFactory.getLogger(MyriaDriver.class);
/** Resolves the driver's local address (used to canonicalize "localhost" hostnames for REEF). */
private final LocalAddressProvider addressProvider;
/** Submits evaluator (container) requests to REEF. */
private final EvaluatorRequestor requestor;
/** Builds JVM process definitions for allocated evaluators. */
private final JVMProcessFactory jvmProcessFactory;
/** Channel for sending status messages back to the launching client. */
private final JobMessageObserver launcher;
/** The deserialized global Myria configuration, merged into every context configuration. */
private final Configuration globalConf;
/** Injector over {@link #globalConf}, used to read global named parameters. */
private final Injector globalConfInjector;
/** Per-worker configuration, keyed by worker ID. */
private final ImmutableMap<Integer, Configuration> workerConfs;
/**
 * This is a workaround for the fact that YARN (and therefore REEF) doesn't allow you to associate an identifier with
 * a container request, so we can't tell which EvaluatorRequest an AllocatedEvaluator corresponds to. Therefore, we
 * assign worker IDs to AllocatedEvaluators in FIFO request order.
 */
private final Queue<Integer> workerIdsPendingEvaluatorAllocation;
/** Running REEF task for each worker, keyed by worker ID. */
private final ConcurrentMap<Integer, RunningTask> tasksByWorkerId;
/** Active REEF context for each worker, keyed by worker ID. */
private final ConcurrentMap<Integer, ActiveContext> contextsByWorkerId;
/** Allocated REEF evaluator for each worker, keyed by worker ID. */
private final ConcurrentMap<Integer, AllocatedEvaluator> evaluatorsByWorkerId;
/** Number of workers not yet READY; starts at the configured worker count. */
private final AtomicInteger numberWorkersPending;
/** Handlers awaiting ADD_WORKER acks from peer workers, keyed by the recovering worker's ID. */
private final ConcurrentMap<Integer, WorkerAckHandler> addWorkerAckHandlers;
/** Handlers awaiting REMOVE_WORKER acks from peer workers, keyed by the failed worker's ID. */
private final ConcurrentMap<Integer, WorkerAckHandler> removeWorkerAckHandlers;
/** Handlers awaiting the coordinator's ADD_WORKER ack, keyed by worker ID. */
private final ConcurrentMap<Integer, CoordinatorAckHandler> addCoordinatorAckHandlers;
/** Handlers awaiting the coordinator's REMOVE_WORKER ack, keyed by worker ID. */
private final ConcurrentMap<Integer, CoordinatorAckHandler> removeCoordinatorAckHandlers;
/** Executor running state transitions (sized to one thread per worker plus the coordinator). */
private final ExecutorService transitionExecutor;
/** Milliseconds to wait for peer-worker acks before proceeding (see recoverWorker/removeWorker). */
private static final int WORKER_ACK_TIMEOUT_MILLIS = 5000;
/** Liveness probe message (not referenced in this section of the file — presumably exchanged with tasks; confirm). */
static final String DRIVER_PING_MSG = "PING";
/** Expected reply to {@link #DRIVER_PING_MSG}. */
static final String DRIVER_PING_ACK_MSG = "PONG";
/**
 * Possible states of the Myria driver. Can be one of:
 * <dl>
 * <dt><code>INIT</code></dt>
 * <dd>Initial state. Ready to request an evaluator.</dd>
 * <dt><code>PREPARING_MASTER</code></dt>
 * <dd>Waiting for master Evaluator/Task to be allocated.</dd>
 * <dt><code>PREPARING_WORKERS</code></dt>
 * <dd>Waiting for each worker's Evaluator/Task to be allocated.</dd>
 * <dt><code>READY</code></dt>
 * <dd>Each per-worker Task is ready to receive queries.</dd>
 * </dl>
 */
private enum DriverState {
  INIT,
  PREPARING_MASTER,
  PREPARING_WORKERS,
  READY
};
/** Current driver state; volatile because it is read and written from multiple handler threads. */
private volatile DriverState state = DriverState.INIT;
/**
 * Per-worker task lifecycle states. A worker advances from PENDING_EVALUATOR_REQUEST through
 * evaluator, context, and task allocation to READY; on failure it enters one of the
 * FAILED_*_PENDING_TASK_FAILED_ACK states until peers and the coordinator have acknowledged
 * the failure (see the transition table in initializeTaskStateTransitions()).
 */
private enum TaskState {
  PENDING_EVALUATOR_REQUEST,
  PENDING_EVALUATOR,
  PENDING_CONTEXT,
  PENDING_TASK,
  PENDING_TASK_RUNNING_ACK,
  READY,
  FAILED_EVALUATOR_PENDING_TASK_FAILED_ACK,
  FAILED_CONTEXT_PENDING_TASK_FAILED_ACK,
  FAILED_TASK_PENDING_TASK_FAILED_ACK
};
/** Per-worker reentrant locks guarding state transitions, striped by worker ID. */
private final Striped<Lock> workerStateTransitionLocks;
/** Current task state of every worker (including the coordinator). */
private final ConcurrentMap<Integer, TaskState> workerStates;
/** Events that drive the per-worker task state machine (dispatched via doTransition()). */
private enum TaskStateEvent {
  EVALUATOR_SUBMITTED,
  EVALUATOR_ALLOCATED,
  CONTEXT_ALLOCATED,
  TASK_RUNNING,
  TASK_RUNNING_ACK,
  TASK_FAILED,
  TASK_FAILED_ACK,
  CONTEXT_FAILED,
  EVALUATOR_FAILED
}
/**
 * An edge of the per-worker task state machine: the state to enter, paired with the
 * side-effecting handler to run when the transition fires.
 */
private static class TaskStateTransition {
  /** Callback executed while a transition is being performed. */
  @FunctionalInterface
  public interface Handler {
    void onTransition(final int workerId, final Object context) throws Exception;
  }

  /** The state the worker enters after this transition. */
  public final TaskState newState;
  /** The action executed when this transition is taken. */
  public final Handler handler;

  private TaskStateTransition(final TaskState newState, final Handler handler) {
    this.newState = newState;
    this.handler = handler;
  }

  /** Static factory: a transition into {@code newState} that runs {@code handler}. */
  public static TaskStateTransition of(final TaskState newState, final Handler handler) {
    return new TaskStateTransition(newState, handler);
  }
}
/** Transition table of the per-worker task state machine: (current state, event) -> transition. */
private final ImmutableTable<TaskState, TaskStateEvent, TaskStateTransition> taskStateTransitions;

@SuppressWarnings("unchecked")
private ImmutableTable<TaskState, TaskStateEvent, TaskStateTransition>
    initializeTaskStateTransitions() {
  return new ImmutableTable.Builder<TaskState, TaskStateEvent, TaskStateTransition>()
      // Evaluator request submitted: remember the worker in FIFO order (see
      // workerIdsPendingEvaluatorAllocation) and hand the request to REEF.
      .put(
          TaskState.PENDING_EVALUATOR_REQUEST,
          TaskStateEvent.EVALUATOR_SUBMITTED,
          TaskStateTransition.of(
              TaskState.PENDING_EVALUATOR,
              (wid, ctx) -> {
                workerIdsPendingEvaluatorAllocation.add(wid);
                requestor.submit((EvaluatorRequest) ctx);
              }))
      // Evaluator granted: record it and submit a context on it.
      .put(
          TaskState.PENDING_EVALUATOR,
          TaskStateEvent.EVALUATOR_ALLOCATED,
          TaskStateTransition.of(
              TaskState.PENDING_CONTEXT,
              (wid, ctx) -> {
                evaluatorsByWorkerId.put(wid, (AllocatedEvaluator) ctx);
                allocateWorkerContext(wid);
              }))
      // Context active: record it and submit the per-worker task.
      .put(
          TaskState.PENDING_CONTEXT,
          TaskStateEvent.CONTEXT_ALLOCATED,
          TaskStateTransition.of(
              TaskState.PENDING_TASK,
              (wid, ctx) -> {
                contextsByWorkerId.put(wid, (ActiveContext) ctx);
                scheduleTask(wid);
              }))
      // Evaluator died while awaiting its context: drop it and request a new evaluator.
      // NOTE(review): requestWorkerEvaluator() fires EVALUATOR_SUBMITTED, which is only defined
      // from PENDING_EVALUATOR_REQUEST; the new state PENDING_EVALUATOR here looks like a latent
      // bug (compare the other EVALUATOR_FAILED edges) — confirm.
      .put(
          TaskState.PENDING_CONTEXT,
          TaskStateEvent.EVALUATOR_FAILED,
          TaskStateTransition.of(
              TaskState.PENDING_EVALUATOR,
              (wid, ctx) -> {
                evaluatorsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                requestWorkerEvaluator(wid);
              }))
      // Task reported running: record it, then run the ADD_WORKER ack protocol.
      .put(
          TaskState.PENDING_TASK,
          TaskStateEvent.TASK_RUNNING,
          TaskStateTransition.of(
              TaskState.PENDING_TASK_RUNNING_ACK,
              (wid, ctx) -> {
                tasksByWorkerId.put(wid, (RunningTask) ctx);
                recoverWorker(wid);
              }))
      // Task failed before it ever ran: just resubmit it on the same context.
      .put(
          TaskState.PENDING_TASK,
          TaskStateEvent.TASK_FAILED,
          TaskStateTransition.of(
              TaskState.PENDING_TASK,
              (wid, ctx) -> {
                updateDriverStateOnWorkerFailure(wid);
                scheduleTask(wid);
              }))
      // Context failed while awaiting the task: drop it and allocate a fresh context.
      .put(
          TaskState.PENDING_TASK,
          TaskStateEvent.CONTEXT_FAILED,
          TaskStateTransition.of(
              TaskState.PENDING_CONTEXT,
              (wid, ctx) -> {
                contextsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                allocateWorkerContext(wid);
              }))
      // Evaluator failed while awaiting the task: drop everything and start over.
      .put(
          TaskState.PENDING_TASK,
          TaskStateEvent.EVALUATOR_FAILED,
          TaskStateTransition.of(
              TaskState.PENDING_EVALUATOR_REQUEST,
              (wid, ctx) -> {
                contextsByWorkerId.remove(wid);
                evaluatorsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                requestWorkerEvaluator(wid);
              }))
      // ADD_WORKER protocol completed: the worker is fully up.
      .put(
          TaskState.PENDING_TASK_RUNNING_ACK,
          TaskStateEvent.TASK_RUNNING_ACK,
          TaskStateTransition.of(
              TaskState.READY, (wid, ctx) -> updateDriverStateOnWorkerReady(wid)))
      // Task failed during the ack protocol: run the REMOVE_WORKER protocol.
      .put(
          TaskState.PENDING_TASK_RUNNING_ACK,
          TaskStateEvent.TASK_FAILED,
          TaskStateTransition.of(
              TaskState.FAILED_TASK_PENDING_TASK_FAILED_ACK,
              (wid, ctx) -> {
                tasksByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                removeWorker(wid);
              }))
      // Context failed during the ack protocol.
      .put(
          TaskState.PENDING_TASK_RUNNING_ACK,
          TaskStateEvent.CONTEXT_FAILED,
          TaskStateTransition.of(
              TaskState.FAILED_CONTEXT_PENDING_TASK_FAILED_ACK,
              (wid, ctx) -> {
                contextsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                removeWorker(wid);
              }))
      // Evaluator failed during the ack protocol.
      .put(
          TaskState.PENDING_TASK_RUNNING_ACK,
          TaskStateEvent.EVALUATOR_FAILED,
          TaskStateTransition.of(
              TaskState.FAILED_EVALUATOR_PENDING_TASK_FAILED_ACK,
              (wid, ctx) -> {
                contextsByWorkerId.remove(wid);
                evaluatorsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                removeWorker(wid);
              }))
      // Task of a READY worker failed: clean up and run the REMOVE_WORKER protocol.
      .put(
          TaskState.READY,
          TaskStateEvent.TASK_FAILED,
          TaskStateTransition.of(
              TaskState.FAILED_TASK_PENDING_TASK_FAILED_ACK,
              (wid, ctx) -> {
                tasksByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                removeWorker(wid);
              }))
      // Context of a READY worker failed.
      .put(
          TaskState.READY,
          TaskStateEvent.CONTEXT_FAILED,
          TaskStateTransition.of(
              TaskState.FAILED_CONTEXT_PENDING_TASK_FAILED_ACK,
              (wid, ctx) -> {
                tasksByWorkerId.remove(wid);
                contextsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                removeWorker(wid);
              }))
      // Evaluator of a READY worker failed.
      .put(
          TaskState.READY,
          TaskStateEvent.EVALUATOR_FAILED,
          TaskStateTransition.of(
              TaskState.FAILED_EVALUATOR_PENDING_TASK_FAILED_ACK,
              (wid, ctx) -> {
                tasksByWorkerId.remove(wid);
                contextsByWorkerId.remove(wid);
                evaluatorsByWorkerId.remove(wid);
                updateDriverStateOnWorkerFailure(wid);
                removeWorker(wid);
              }))
      // REMOVE_WORKER acked after a task failure: resubmit the task on the surviving context.
      .put(
          TaskState.FAILED_TASK_PENDING_TASK_FAILED_ACK,
          TaskStateEvent.TASK_FAILED_ACK,
          TaskStateTransition.of(
              TaskState.PENDING_TASK,
              (wid, ctx) -> {
                scheduleTask(wid);
              }))
      // REMOVE_WORKER acked after a context failure: allocate a fresh context.
      .put(
          TaskState.FAILED_CONTEXT_PENDING_TASK_FAILED_ACK,
          TaskStateEvent.TASK_FAILED_ACK,
          TaskStateTransition.of(
              TaskState.PENDING_CONTEXT,
              (wid, ctx) -> {
                allocateWorkerContext(wid);
              }))
      // REMOVE_WORKER acked after an evaluator failure: request a fresh evaluator.
      .put(
          TaskState.FAILED_EVALUATOR_PENDING_TASK_FAILED_ACK,
          TaskStateEvent.TASK_FAILED_ACK,
          TaskStateTransition.of(
              TaskState.PENDING_EVALUATOR_REQUEST,
              (wid, ctx) -> {
                requestWorkerEvaluator(wid);
              }))
      .build();
}
/**
 * Performs a state transition for the given worker under the worker's striped lock: looks up the
 * transition for (current state, event), installs the new state, then runs the transition
 * handler.
 *
 * @param workerId the worker whose state machine is driven
 * @param event the event triggering the transition
 * @param context event-specific payload passed to the transition handler (may be null)
 * @throws IllegalStateException if no transition is defined for the current state and event
 * @throws RuntimeException wrapping any exception thrown by the transition handler
 */
public void doTransition(final int workerId, final TaskStateEvent event, final Object context) {
  // NB: this lock is reentrant, so any transitions induced by the transition handler will not
  // deadlock the thread.
  final Lock workerLock = workerStateTransitionLocks.get(workerId);
  workerLock.lock();
  try {
    final TaskState workerState = workerStates.get(workerId);
    final TaskStateTransition transition = taskStateTransitions.get(workerState, event);
    if (transition != null) {
      // The new state is installed before the handler runs, so the handler (and any nested
      // transitions it induces) observes the post-transition state.
      workerStates.replace(workerId, transition.newState);
      LOGGER.info(
          "Performing transition on event {} from state {} to state {} (worker ID {}, context {})",
          event,
          workerState,
          transition.newState,
          workerId,
          context);
      try {
        transition.handler.onTransition(workerId, context);
      } catch (final Exception e) {
        // Handlers declare checked exceptions; surface them as unchecked to the caller.
        throw new RuntimeException(e);
      }
    } else {
      throw new IllegalStateException(
          String.format(
              "No transition defined for state %s and event %s (worker ID %s)",
              workerState,
              event,
              workerId));
    }
  } finally {
    workerLock.unlock();
  }
}
/**
 * Schedules a worker transition and its handler to be executed on a separate thread pool.
 *
 * @param workerId the worker whose state machine is driven
 * @param event the event triggering the transition
 * @param context event-specific payload passed to the transition handler (may be null)
 */
public void scheduleTransition(
    final int workerId, final TaskStateEvent event, final Object context) {
  transitionExecutor.execute(() -> doTransition(workerId, event, context));
}
/** Callback invoked when a peer worker acknowledges an ADD_WORKER/REMOVE_WORKER message. */
@FunctionalInterface
private interface WorkerAckHandler {
  public void onAck(final int workerId, final int senderId) throws Exception;
}

/** Callback invoked when the coordinator acknowledges an ADD_WORKER/REMOVE_WORKER message. */
@FunctionalInterface
private interface CoordinatorAckHandler {
  public void onAck(final int workerId) throws Exception;
}
/**
 * Injected constructor: deserializes the global configuration, derives per-worker
 * configurations, and initializes all state-machine bookkeeping structures.
 *
 * @param addressProvider provides the driver's local address
 * @param requestor used to submit evaluator requests to REEF
 * @param jvmProcessFactory used to build evaluator JVM processes
 * @param launcher channel for status messages back to the launching client
 * @param serializedGlobalConf the Avro-serialized global Myria configuration
 * @throws Exception if configuration deserialization or injection fails
 */
@Inject
public MyriaDriver(
    final LocalAddressProvider addressProvider,
    final EvaluatorRequestor requestor,
    final JVMProcessFactory jvmProcessFactory,
    final JobMessageObserver launcher,
    final @Parameter(MyriaDriverLauncher.SerializedGlobalConf.class) String serializedGlobalConf)
    throws Exception {
  this.requestor = requestor;
  this.addressProvider = addressProvider;
  this.jvmProcessFactory = jvmProcessFactory;
  this.launcher = launcher;
  globalConf = new AvroConfigurationSerializer().fromString(serializedGlobalConf);
  globalConfInjector = Tang.Factory.getTang().newInjector(globalConf);
  workerConfs = initializeWorkerConfs();
  workerIdsPendingEvaluatorAllocation = new ConcurrentLinkedQueue<>();
  tasksByWorkerId = new ConcurrentHashMap<>();
  contextsByWorkerId = new ConcurrentHashMap<>();
  evaluatorsByWorkerId = new ConcurrentHashMap<>();
  numberWorkersPending = new AtomicInteger(workerConfs.size());
  addWorkerAckHandlers = new ConcurrentHashMap<>();
  removeWorkerAckHandlers = new ConcurrentHashMap<>();
  addCoordinatorAckHandlers = new ConcurrentHashMap<>();
  removeCoordinatorAckHandlers = new ConcurrentHashMap<>();
  workerStates = initializeWorkerStates();
  taskStateTransitions = initializeTaskStateTransitions();
  workerStateTransitionLocks = Striped.lock(workerConfs.size() + 1); // +1 for coordinator
  // Since all worker transitions acquire a per-worker lock, concurrency is limited to the number
  // of workers.
  transitionExecutor =
      Executors.newFixedThreadPool(
          workerConfs.size() + 1, new RenamingThreadFactory("WorkerTransitionThreadPool"));
}
/**
 * Returns the master hostname in the form REEF expects: if the configured host is
 * "localhost"/"127.0.0.1" it is canonicalized via the local address provider, otherwise it is
 * returned unchanged.
 *
 * @return the (possibly canonicalized) master hostname
 * @throws InjectionException if the master host cannot be read from the global configuration
 */
private String getMasterHost() throws InjectionException {
  final String masterHost =
      globalConfInjector.getNamedInstance(MyriaGlobalConfigurationModule.MasterHost.class);
  // REEF (org.apache.reef.wake.remote.address.HostnameBasedLocalAddressProvider) will
  // unpredictably pick a local DNS name or IP address instead of "localhost" or 127.0.0.1
  String reefMasterHost = masterHost;
  if (masterHost.equals("localhost") || masterHost.equals("127.0.0.1")) {
    try {
      reefMasterHost = InetAddress.getByName(addressProvider.getLocalAddress()).getHostName();
      LOGGER.info(
          "Original host: {}, HostnameBasedLocalAddressProvider returned {}",
          masterHost,
          reefMasterHost);
    } catch (final UnknownHostException e) {
      // Pass the exception so the stack trace is logged (the original warn dropped it), then
      // fall back to the configured host.
      LOGGER.warn("Failed to get canonical hostname for host {}", masterHost, e);
    }
  }
  return reefMasterHost;
}
/**
 * Deserializes the per-worker configurations embedded in the global configuration and indexes
 * them by worker ID.
 *
 * @return an immutable map from worker ID to that worker's configuration
 * @throws InjectionException if a named parameter cannot be injected
 * @throws BindException if a configuration cannot be bound
 * @throws IOException if deserialization fails
 */
private ImmutableMap<Integer, Configuration> initializeWorkerConfs()
    throws InjectionException, BindException, IOException {
  final ConfigurationSerializer serializer = new AvroConfigurationSerializer();
  final Set<String> serializedConfs =
      globalConfInjector.getNamedInstance(MyriaGlobalConfigurationModule.WorkerConf.class);
  final ImmutableMap.Builder<Integer, Configuration> confsByWorkerId =
      new ImmutableMap.Builder<>();
  for (final String confString : serializedConfs) {
    final Configuration conf = serializer.fromString(confString);
    confsByWorkerId.put(getIdFromWorkerConf(conf), conf);
  }
  return confsByWorkerId.build();
}
/** Extracts the worker ID bound in the given worker configuration. */
private Integer getIdFromWorkerConf(final Configuration workerConf) throws InjectionException {
  return Tang.Factory.getTang()
      .newInjector(workerConf)
      .getNamedInstance(MyriaWorkerConfigurationModule.WorkerId.class);
}
/**
 * Returns the worker hostname in the form REEF expects: if the configured host is
 * "localhost"/"127.0.0.1" it is canonicalized via the local address provider, otherwise it is
 * returned unchanged.
 *
 * @param workerConf the worker's configuration
 * @return the (possibly canonicalized) worker hostname
 * @throws InjectionException if the host cannot be read from the worker configuration
 */
private String getHostFromWorkerConf(final Configuration workerConf) throws InjectionException {
  final Injector injector = Tang.Factory.getTang().newInjector(workerConf);
  final String host = injector.getNamedInstance(MyriaWorkerConfigurationModule.WorkerHost.class);
  // REEF (org.apache.reef.wake.remote.address.HostnameBasedLocalAddressProvider) will
  // unpredictably pick a local DNS name or IP address instead of "localhost" or 127.0.0.1
  String reefHost = host;
  if (host.equals("localhost") || host.equals("127.0.0.1")) {
    try {
      reefHost = InetAddress.getByName(addressProvider.getLocalAddress()).getHostName();
      LOGGER.info(
          "Original host: {}, HostnameBasedLocalAddressProvider returned {}", host, reefHost);
    } catch (final UnknownHostException e) {
      // Pass the exception so the stack trace is logged (the original warn dropped it), then
      // fall back to the configured host.
      LOGGER.warn("Failed to get canonical hostname for host {}", host, e);
    }
  }
  return reefHost;
}
/**
 * Builds the initial state map: every worker (and the coordinator) starts in
 * PENDING_EVALUATOR_REQUEST.
 */
private ConcurrentMap<Integer, TaskState> initializeWorkerStates() {
  final ConcurrentMap<Integer, TaskState> initialStates =
      new ConcurrentHashMap<>(workerConfs.size() + 1);
  initialStates.put(MyriaConstants.MASTER_ID, TaskState.PENDING_EVALUATOR_REQUEST);
  workerConfs
      .keySet()
      .forEach(workerId -> initialStates.put(workerId, TaskState.PENDING_EVALUATOR_REQUEST));
  return initialStates;
}
/**
 * Builds the advertised socket address (host and port) for the given worker from its
 * configuration. The hostname is deliberately read directly from the configuration rather than
 * through getHostFromWorkerConf(), so the original (non-canonicalized) hostname is preserved.
 */
private SocketInfo getSocketInfoForWorker(final int workerId) throws InjectionException {
  final Injector workerConfInjector =
      Tang.Factory.getTang().newInjector(workerConfs.get(workerId));
  // we don't use getHostFromWorkerConf() because we want to keep our original hostname
  final String workerHost =
      workerConfInjector.getNamedInstance(MyriaWorkerConfigurationModule.WorkerHost.class);
  final Integer workerPort =
      workerConfInjector.getNamedInstance(MyriaWorkerConfigurationModule.WorkerPort.class);
  return new SocketInfo(workerHost, workerPort);
}
/**
 * Builds an evaluator request for the given (non-coordinator) worker, sized from the global
 * worker memory/vcore quotas and pinned to the worker's configured host, then fires the
 * EVALUATOR_SUBMITTED transition (whose handler submits the request to REEF).
 *
 * @param workerId the worker needing an evaluator; must not be the coordinator
 * @throws InjectionException if the quotas or host cannot be read from configuration
 */
private void requestWorkerEvaluator(final int workerId) throws InjectionException {
  Preconditions.checkArgument(workerId != MyriaConstants.MASTER_ID);
  // quota is configured in GB; REEF wants MB
  final int jvmMemoryQuotaMB =
      (int)
          (1024
              * globalConfInjector.getNamedInstance(
                  MyriaGlobalConfigurationModule.WorkerMemoryQuotaGB.class));
  final int numberVCores =
      globalConfInjector.getNamedInstance(
          MyriaGlobalConfigurationModule.WorkerNumberVCores.class);
  LOGGER.info(
      "Requesting evaluator for worker {} with {} vcores, {} MB memory.",
      workerId,
      numberVCores,
      jvmMemoryQuotaMB);
  final Configuration workerConf = workerConfs.get(workerId);
  final String hostname = getHostFromWorkerConf(workerConf);
  final EvaluatorRequest workerRequest =
      EvaluatorRequest.newBuilder()
          .setNumber(1)
          .setMemory(jvmMemoryQuotaMB)
          .setNumberOfCores(numberVCores)
          .addNodeName(hostname)
          .build();
  doTransition(workerId, TaskStateEvent.EVALUATOR_SUBMITTED, workerRequest);
}
/**
 * Builds an evaluator request for the coordinator, sized from the global master memory/vcore
 * quotas and pinned to the master host, then fires the EVALUATOR_SUBMITTED transition (whose
 * handler submits the request to REEF).
 *
 * @throws InjectionException if the quotas or host cannot be read from configuration
 */
private void requestMasterEvaluator() throws InjectionException {
  final String masterHost = getMasterHost();
  // quota is configured in GB; REEF wants MB
  final int jvmMemoryQuotaMB =
      (int)
          (1024
              * globalConfInjector.getNamedInstance(
                  MyriaGlobalConfigurationModule.MasterMemoryQuotaGB.class));
  final int numberVCores =
      globalConfInjector.getNamedInstance(
          MyriaGlobalConfigurationModule.MasterNumberVCores.class);
  LOGGER.info(
      "Requesting master evaluator with {} vcores, {} MB memory.",
      numberVCores,
      jvmMemoryQuotaMB);
  final EvaluatorRequest masterRequest =
      EvaluatorRequest.newBuilder()
          .setNumber(1)
          .setMemory(jvmMemoryQuotaMB)
          .setNumberOfCores(numberVCores)
          .addNodeName(masterHost)
          .build();
  doTransition(MyriaConstants.MASTER_ID, TaskStateEvent.EVALUATOR_SUBMITTED, masterRequest);
}
/**
 * Configures the JVM process for the given evaluator: min/max heap from the master or worker
 * quota (GB in configuration, converted to MB), the native library path, and any extra
 * globally-configured JVM options.
 *
 * @param evaluator the evaluator whose JVM process is being configured
 * @param isMaster whether to apply the master (vs. worker) heap quotas
 * @throws InjectionException if the quotas or options cannot be read from configuration
 */
private void setJVMOptions(final AllocatedEvaluator evaluator, final boolean isMaster)
    throws InjectionException {
  final int jvmHeapSizeMinMB =
      (int)
          (1024
              * (isMaster
                  ? globalConfInjector.getNamedInstance(
                      MyriaGlobalConfigurationModule.MasterJvmHeapSizeMinGB.class)
                  : globalConfInjector.getNamedInstance(
                      MyriaGlobalConfigurationModule.WorkerJvmHeapSizeMinGB.class)));
  final int jvmHeapSizeMaxMB =
      (int)
          (1024
              * (isMaster
                  ? globalConfInjector.getNamedInstance(
                      MyriaGlobalConfigurationModule.MasterJvmHeapSizeMaxGB.class)
                  : globalConfInjector.getNamedInstance(
                      MyriaGlobalConfigurationModule.WorkerJvmHeapSizeMaxGB.class)));
  final Set<String> jvmOptions =
      globalConfInjector.getNamedInstance(MyriaGlobalConfigurationModule.JvmOptions.class);
  final JVMProcess jvmProcess =
      jvmProcessFactory
          .newEvaluatorProcess()
          .addOption(String.format("-Xms%dm", jvmHeapSizeMinMB))
          .addOption(String.format("-Xmx%dm", jvmHeapSizeMaxMB))
          // for native libraries
          .addOption("-Djava.library.path=./reef/global");
  for (final String option : jvmOptions) {
    jvmProcess.addOption(option);
  }
  evaluator.setProcess(jvmProcess);
}
/** Requests the master evaluator; only legal in the PREPARING_MASTER driver state. */
private void launchMaster() throws InjectionException {
  Preconditions.checkState(state == DriverState.PREPARING_MASTER);
  requestMasterEvaluator();
}
/** Requests one evaluator per configured worker; only legal in the PREPARING_WORKERS state. */
private void launchWorkers() throws InjectionException {
  Preconditions.checkState(state == DriverState.PREPARING_WORKERS);
  for (final Integer workerId : workerConfs.keySet()) {
    requestWorkerEvaluator(workerId);
  }
}
/**
 * Submits a REEF context on the worker's allocated evaluator, after configuring the evaluator's
 * JVM options. The context configuration is the merge of a per-worker-identified context conf,
 * the worker's own configuration (for non-coordinator workers), and the global configuration.
 *
 * @param workerId the worker whose context is being launched
 * @throws InjectionException if JVM options cannot be read from the global configuration
 */
private void allocateWorkerContext(final int workerId) throws InjectionException {
  Preconditions.checkState(evaluatorsByWorkerId.containsKey(workerId));
  final AllocatedEvaluator evaluator = evaluatorsByWorkerId.get(workerId);
  LOGGER.info(
      "Launching context for worker ID {} on {}",
      workerId,
      evaluator.getEvaluatorDescriptor().getNodeDescriptor().getName());
  Preconditions.checkState(!contextsByWorkerId.containsKey(workerId));
  // the context is identified by the worker ID
  Configuration contextConf =
      ContextConfiguration.CONF.set(ContextConfiguration.IDENTIFIER, workerId + "").build();
  setJVMOptions(evaluator, (workerId == MyriaConstants.MASTER_ID));
  if (workerId != MyriaConstants.MASTER_ID) {
    contextConf = Configurations.merge(contextConf, workerConfs.get(workerId));
  }
  evaluator.submitContext(Configurations.merge(contextConf, globalConf));
}
/**
 * Submits the persistent REEF task for the given worker on its active context: the MasterDaemon
 * task for the coordinator, or a Worker task otherwise. The task is identified by the worker ID.
 *
 * @param workerId the worker whose task is being scheduled
 */
private void scheduleTask(final int workerId) {
  Preconditions.checkState(contextsByWorkerId.containsKey(workerId));
  final ActiveContext context = contextsByWorkerId.get(workerId);
  LOGGER.info(
      "Scheduling task for worker ID {} on context {}, evaluator {}",
      workerId,
      context.getId(),
      context.getEvaluatorId());
  Configuration taskConf;
  if (workerId == MyriaConstants.MASTER_ID) {
    // the coordinator task is only ever scheduled during initial master startup
    Preconditions.checkState(state == DriverState.PREPARING_MASTER);
    taskConf =
        TaskConfiguration.CONF
            .set(TaskConfiguration.TASK, MasterDaemon.class)
            .set(TaskConfiguration.IDENTIFIER, workerId + "")
            .set(TaskConfiguration.ON_SEND_MESSAGE, Server.class)
            .set(TaskConfiguration.ON_MESSAGE, Server.class)
            .set(TaskConfiguration.ON_CLOSE, MasterDaemon.class)
            .build();
  } else {
    // worker tasks may be (re)scheduled during startup or after failures in steady state
    Preconditions.checkState(
        state == DriverState.PREPARING_WORKERS || state == DriverState.READY);
    taskConf =
        TaskConfiguration.CONF
            .set(TaskConfiguration.TASK, Worker.class)
            .set(TaskConfiguration.IDENTIFIER, workerId + "")
            .set(TaskConfiguration.ON_SEND_MESSAGE, Worker.class)
            .set(TaskConfiguration.ON_MESSAGE, Worker.DriverMessageHandler.class)
            .set(TaskConfiguration.ON_CLOSE, Worker.TaskCloseHandler.class)
            .build();
  }
  context.submitTask(taskConf);
}
/** Returns the IDs of all non-coordinator workers currently in the READY state. */
private ImmutableSet<Integer> getAliveWorkers() {
  final ImmutableSet.Builder<Integer> aliveWorkers = ImmutableSet.builder();
  for (final Integer wid : workerStates.keySet()) {
    // keys are never removed from workerStates, so the racy get() here is safe
    if (!wid.equals(MyriaConstants.MASTER_ID)
        && TaskState.READY.equals(workerStates.get(wid))) {
      aliveWorkers.add(wid);
    }
  }
  return aliveWorkers.build();
}
/**
 * Sends the given message to a worker's running task, but only if the worker is not currently in
 * a state transition (its striped lock can be acquired without blocking).
 *
 * @param workerId the worker to send the message to
 * @param message the message to send
 * @return true if the message was sent, false if the worker was transitioning or had no task
 */
private boolean sendMessageToWorker(final int workerId, final TransportMessage message) {
  boolean messageSent = false;
  // Hoist the striped-lock lookup so tryLock() and unlock() provably operate on the same Lock
  // object (the original fetched it twice).
  final Lock workerLock = workerStateTransitionLocks.get(workerId);
  // if the worker we're sending this message to is in a state transition, we abort
  if (workerLock.tryLock()) {
    try {
      final RunningTask workerToNotifyTask = tasksByWorkerId.get(workerId);
      if (workerToNotifyTask != null) {
        workerToNotifyTask.send(message.toByteArray());
        messageSent = true;
      }
    } finally {
      workerLock.unlock();
    }
  } else {
    LOGGER.warn(
        "worker {} is in a state transition (current state: {}), aborting send of message: {}",
        workerId,
        workerStates.get(workerId),
        message);
  }
  return messageSent;
}
/**
 * Sends the given message to the coordinator's running task.
 *
 * @param message the message to send; the coordinator must be in the READY state
 */
private void sendMessageToCoordinator(final TransportMessage message) {
  // The coordinator can never be in a state transition after it comes up, so it should always be
  // safe to send it messages concurrently.
  final TaskState coordinatorState = workerStates.get(MyriaConstants.MASTER_ID);
  Preconditions.checkState(coordinatorState.equals(TaskState.READY));
  tasksByWorkerId.get(MyriaConstants.MASTER_ID).send(message.toByteArray());
}
/** Registers the handler invoked for each peer ADD_WORKER ack concerning the given worker. */
private void registerWorkerAddAckHandler(final int workerId, final WorkerAckHandler handler) {
  addWorkerAckHandlers.put(workerId, handler);
}
/** Registers the handler invoked for each peer REMOVE_WORKER ack concerning the given worker. */
private void registerWorkerRemoveAckHandler(final int workerId, final WorkerAckHandler handler) {
  removeWorkerAckHandlers.put(workerId, handler);
}
/** Registers the handler invoked for the coordinator's ADD_WORKER ack for the given worker. */
private void registerCoordinatorAddAckHandler(
    final int workerId, final CoordinatorAckHandler handler) {
  addCoordinatorAckHandlers.put(workerId, handler);
}
/** Registers the handler invoked for the coordinator's REMOVE_WORKER ack for the given worker. */
private void registerCoordinatorRemoveAckHandler(
    final int workerId, final CoordinatorAckHandler handler) {
  removeCoordinatorAckHandlers.put(workerId, handler);
}
/** Removes the peer ADD_WORKER ack handler for the given worker. */
private void unregisterWorkerAddAckHandler(final int workerId) {
  addWorkerAckHandlers.remove(workerId);
}
/** Removes the peer REMOVE_WORKER ack handler for the given worker. */
private void unregisterWorkerRemoveAckHandler(final int workerId) {
  removeWorkerAckHandlers.remove(workerId);
}
/** Removes the coordinator ADD_WORKER ack handler for the given worker. */
private void unregisterCoordinatorAddAckHandler(final int workerId) {
  addCoordinatorAckHandlers.remove(workerId);
}
/** Removes the coordinator REMOVE_WORKER ack handler for the given worker. */
private void unregisterCoordinatorRemoveAckHandler(final int workerId) {
  removeCoordinatorAckHandlers.remove(workerId);
}
/**
 * Runs the ADD_WORKER protocol for a worker whose task just started running: notifies all alive
 * workers and collects their acks (bounded by WORKER_ACK_TIMEOUT_MILLIS), then notifies the
 * coordinator and waits (unbounded) for its ack, and finally fires the TASK_RUNNING_ACK
 * transition. For the coordinator itself no acks are gathered.
 *
 * @param workerId the worker that just came up (may be MyriaConstants.MASTER_ID)
 * @throws InterruptedException if interrupted while awaiting acks
 */
private void recoverWorker(final int workerId) throws InterruptedException {
  if (workerId != MyriaConstants.MASTER_ID) {
    // this is obviously racy but it doesn't matter since we timeout on acks
    final ImmutableSet<Integer> aliveWorkers = getAliveWorkers();
    final CountDownLatch acksPending = new CountDownLatch(aliveWorkers.size());
    final Set<Integer> ackedWorkers = Sets.newConcurrentHashSet();
    registerWorkerAddAckHandler(
        workerId,
        (wid, sid) -> {
          ackedWorkers.add(sid);
          acksPending.countDown();
        });
    LOGGER.info(
        "Sending ADD_WORKER for worker {} to all {} alive workers",
        workerId,
        aliveWorkers.size());
    for (final Integer aliveWorkerId : aliveWorkers) {
      notifyWorkerOnRecovery(workerId, aliveWorkerId);
    }
    final boolean timedOut = !acksPending.await(WORKER_ACK_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
    if (timedOut) {
      LOGGER.info(
          "Timed out after {} ms while waiting for {} acks for ADD_WORKER on worker {}",
          WORKER_ACK_TIMEOUT_MILLIS,
          aliveWorkers.size(),
          workerId);
    }
    // Strictly speaking, this is incorrect since a stale ack could arrive during the next phase
    // in this state. We would need per-worker epoch counters to prevent this (advance the epoch
    // on each state transition, check the epoch of an ack message when it arrives and reject it
    // if stale). If this sort of bug ever shows up in tests, we'll consider this solution.
    unregisterWorkerAddAckHandler(workerId);
    LOGGER.info(
        "Received {} of expected {} acks for ADD_WORKER on worker {}",
        ackedWorkers.size(),
        aliveWorkers.size(),
        workerId);
    final CountDownLatch coordinatorAcked = new CountDownLatch(1);
    registerCoordinatorAddAckHandler(
        workerId,
        (wid) -> {
          coordinatorAcked.countDown();
        });
    notifyCoordinatorOnRecovery(workerId, ackedWorkers);
    // NOTE(review): no timeout here — a lost coordinator ack would block this transition thread
    // indefinitely; confirm this is intended.
    coordinatorAcked.await();
    unregisterCoordinatorAddAckHandler(workerId);
    // we need to perform the transition in our own thread or we can't re-enter the lock
    doTransition(workerId, TaskStateEvent.TASK_RUNNING_ACK, ackedWorkers);
  } else {
    // coordinator can't get any acks when it starts
    doTransition(workerId, TaskStateEvent.TASK_RUNNING_ACK, ImmutableSet.of());
  }
}
/**
 * Runs the REMOVE_WORKER protocol for a failed worker: notifies all alive workers and collects
 * their acks (bounded by WORKER_ACK_TIMEOUT_MILLIS), then notifies the coordinator and waits
 * (unbounded) for its ack, and finally fires the TASK_FAILED_ACK transition. Never called for
 * the coordinator.
 *
 * @param workerId the failed worker (must not be MyriaConstants.MASTER_ID)
 * @throws InterruptedException if interrupted while awaiting acks
 */
private void removeWorker(final int workerId) throws InterruptedException {
  Preconditions.checkState(workerId != MyriaConstants.MASTER_ID);
  // this is obviously racy but it doesn't matter since we timeout on acks
  final ImmutableSet<Integer> aliveWorkers = getAliveWorkers();
  final CountDownLatch acksPending = new CountDownLatch(aliveWorkers.size());
  final Set<Integer> ackedWorkers = Sets.newConcurrentHashSet();
  registerWorkerRemoveAckHandler(
      workerId,
      (wid, sid) -> {
        ackedWorkers.add(sid);
        acksPending.countDown();
      });
  LOGGER.info(
      "Sending REMOVE_WORKER for worker {} to all {} alive workers",
      workerId,
      aliveWorkers.size());
  for (final Integer aliveWorkerId : aliveWorkers) {
    notifyWorkerOnFailure(workerId, aliveWorkerId);
  }
  final boolean timedOut = !acksPending.await(WORKER_ACK_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
  if (timedOut) {
    LOGGER.info(
        "Timed out after {} ms while waiting for {} acks for REMOVE_WORKER on worker {}",
        WORKER_ACK_TIMEOUT_MILLIS,
        aliveWorkers.size(),
        workerId);
  }
  // Strictly speaking, this is incorrect since a stale ack could arrive during the next phase
  // in this state. We would need per-worker epoch counters to prevent this (advance the epoch
  // on each state transition, check the epoch of an ack message when it arrives and reject it
  // if stale). If this sort of bug ever shows up in tests, we'll consider this solution.
  unregisterWorkerRemoveAckHandler(workerId);
  LOGGER.info(
      "Received {} of expected {} acks for REMOVE_WORKER on worker {}",
      ackedWorkers.size(),
      aliveWorkers.size(),
      workerId);
  final CountDownLatch coordinatorAcked = new CountDownLatch(1);
  registerCoordinatorRemoveAckHandler(
      workerId,
      (wid) -> {
        coordinatorAcked.countDown();
      });
  notifyCoordinatorOnFailure(workerId, ackedWorkers);
  // NOTE(review): no timeout here — a lost coordinator ack would block this transition thread
  // indefinitely; confirm this is intended.
  coordinatorAcked.await();
  unregisterCoordinatorRemoveAckHandler(workerId);
  // we need to perform the transition in our own thread or we can't re-enter the lock
  doTransition(workerId, TaskStateEvent.TASK_FAILED_ACK, ackedWorkers);
}
/**
 * Sends a REMOVE_WORKER message about the failed worker to a single peer worker; logs a warning
 * if the peer could not be reached.
 */
private void notifyWorkerOnFailure(final int workerId, final int workerToNotifyId) {
  // we should never get here on coordinator failure
  Preconditions.checkArgument(workerId != MyriaConstants.MASTER_ID);
  LOGGER.info("Sending REMOVE_WORKER for worker {} to worker {}", workerId, workerToNotifyId);
  final TransportMessage removeMessage = IPCUtils.removeWorkerTM(workerId, null);
  final boolean delivered = sendMessageToWorker(workerToNotifyId, removeMessage);
  if (!delivered) {
    LOGGER.warn(
        "Unable to send REMOVE_WORKER for worker {} to worker {}", workerId, workerToNotifyId);
  }
}
/**
 * Sends an ADD_WORKER message (with the recovered worker's socket address) to a single peer
 * worker; logs a warning if the peer could not be reached.
 */
private void notifyWorkerOnRecovery(final int workerId, final int workerToNotifyId) {
  // we should never get here on coordinator failure
  Preconditions.checkArgument(workerId != MyriaConstants.MASTER_ID);
  SocketInfo si;
  try {
    si = getSocketInfoForWorker(workerId);
  } catch (final InjectionException e) {
    // Pass the exception as the last argument (with no matching placeholder) so SLF4J logs its
    // stack trace; the original "{}:\n{}" format consumed the Throwable as a plain argument.
    LOGGER.error("Failed to get SocketInfo for worker {}", workerId, e);
    return;
  }
  LOGGER.info("Sending ADD_WORKER for worker {} to worker {}", workerId, workerToNotifyId);
  final TransportMessage workerRecovered = IPCUtils.addWorkerTM(workerId, si, null);
  if (!sendMessageToWorker(workerToNotifyId, workerRecovered)) {
    LOGGER.warn(
        "Unable to send ADD_WORKER for worker {} to worker {}", workerId, workerToNotifyId);
  }
}
/**
 * Sends a REMOVE_WORKER message (with the set of workers that acked the removal) to the
 * coordinator.
 */
private void notifyCoordinatorOnFailure(final int workerId, final Set<Integer> ackedWorkers) {
  Preconditions.checkState(workerId != MyriaConstants.MASTER_ID);
  LOGGER.info("Sending REMOVE_WORKER for worker {} to coordinator", workerId);
  sendMessageToCoordinator(IPCUtils.removeWorkerTM(workerId, ackedWorkers));
}
/**
 * Sends an ADD_WORKER message (with the recovered worker's socket address and the set of workers
 * that acked the recovery) to the coordinator.
 */
private void notifyCoordinatorOnRecovery(final int workerId, final Set<Integer> ackedWorkers) {
  Preconditions.checkState(workerId != MyriaConstants.MASTER_ID);
  SocketInfo si;
  try {
    si = getSocketInfoForWorker(workerId);
  } catch (final InjectionException e) {
    // Pass the exception as the last argument (with no matching placeholder) so SLF4J logs its
    // stack trace; the original "{}:\n{}" format consumed the Throwable as a plain argument.
    LOGGER.error("Failed to get SocketInfo for worker {}", workerId, e);
    return;
  }
  LOGGER.info("Sending ADD_WORKER for worker {} to coordinator", workerId);
  final TransportMessage workerRecovered = IPCUtils.addWorkerTM(workerId, si, ackedWorkers);
  sendMessageToCoordinator(workerRecovered);
}
/** Dispatches a peer ADD_WORKER ack to the registered handler, or logs if none is registered. */
private void onWorkerAddAck(final int workerId, final int senderId) {
  WorkerAckHandler ackHandler = addWorkerAckHandlers.get(workerId);
  if (ackHandler == null) {
    ackHandler =
        (wid, sid) ->
            LOGGER.warn(
                "No worker handler registered for ADD_WORKER_ACK (worker ID {}, sender ID {})",
                wid,
                sid);
  }
  try {
    ackHandler.onAck(workerId, senderId);
  } catch (final Exception e) {
    throw new RuntimeException(e);
  }
}
/**
 * Dispatches a REMOVE_WORKER_ACK received from a worker to the handler registered for the
 * removed worker, logging a warning if no handler is registered.
 *
 * @param workerId the ID of the removed worker being acknowledged
 * @param senderId the ID of the worker that sent the ack
 */
private void onWorkerRemoveAck(final int workerId, final int senderId) {
  WorkerAckHandler handler = removeWorkerAckHandlers.get(workerId);
  if (handler == null) {
    // No handler registered for this worker: fall back to a warning-only handler.
    handler =
        (wid, sid) ->
            LOGGER.warn(
                "No worker handler registered for REMOVE_WORKER_ACK (worker ID {}, sender ID {})",
                wid,
                sid);
  }
  try {
    handler.onAck(workerId, senderId);
  } catch (final Exception e) {
    throw new RuntimeException(e);
  }
}
/**
 * Dispatches an ADD_WORKER_ACK received from the coordinator to the handler registered for the
 * added worker, logging a warning if no handler is registered.
 *
 * @param workerId the ID of the added worker being acknowledged
 */
private void onCoordinatorAddAck(final int workerId) {
  CoordinatorAckHandler handler = addCoordinatorAckHandlers.get(workerId);
  if (handler == null) {
    // No handler registered for this worker: fall back to a warning-only handler.
    handler =
        wid ->
            LOGGER.warn(
                "No coordinator handler registered for ADD_WORKER_ACK (worker ID {})", wid);
  }
  try {
    handler.onAck(workerId);
  } catch (final Exception e) {
    throw new RuntimeException(e);
  }
}
/**
 * Dispatches a REMOVE_WORKER_ACK received from the coordinator to the handler registered for
 * the removed worker, logging a warning if no handler is registered.
 *
 * @param workerId the ID of the removed worker being acknowledged
 */
private void onCoordinatorRemoveAck(final int workerId) {
  CoordinatorAckHandler handler = removeCoordinatorAckHandlers.get(workerId);
  if (handler == null) {
    // No handler registered for this worker: fall back to a warning-only handler.
    handler =
        wid ->
            LOGGER.warn(
                "No coordinator handler registered for REMOVE_WORKER_ACK (worker ID {})", wid);
  }
  try {
    handler.onAck(workerId);
  } catch (final Exception e) {
    throw new RuntimeException(e);
  }
}
/**
 * Advances the driver's launch state machine when a task reports ready, and notifies the
 * client launcher of progress. In PREPARING_MASTER the ready task must be the coordinator,
 * and its readiness triggers launching all workers; in PREPARING_WORKERS each ready worker
 * decrements the pending count until the driver transitions to READY.
 *
 * <p>Note: the state field is updated BEFORE {@code launchWorkers()} is invoked, so callbacks
 * triggered by the launch observe the PREPARING_WORKERS state.
 *
 * @param workerId the ID of the worker (or coordinator) whose task is now running and ready
 * @throws InjectionException if launching the workers fails
 */
private void updateDriverStateOnWorkerReady(final int workerId) throws InjectionException {
Preconditions.checkState(tasksByWorkerId.containsKey(workerId));
String message = String.format("Worker %s ready", workerId);
launcher.sendMessageToClient(message.getBytes(StandardCharsets.UTF_8));
if (state == DriverState.PREPARING_MASTER) {
// The coordinator is launched first; only it can become ready in this state.
Preconditions.checkState(workerId == MyriaConstants.MASTER_ID);
message = String.format("Master is running, starting %s workers...", workerConfs.size());
LOGGER.info(message);
launcher.sendMessageToClient(message.getBytes(StandardCharsets.UTF_8));
state = DriverState.PREPARING_WORKERS;
launchWorkers();
} else if (state == DriverState.PREPARING_WORKERS) {
// The coordinator was already launched; only regular workers become ready now.
Preconditions.checkState(workerId != MyriaConstants.MASTER_ID);
// Atomically count down; the thread that sees zero performs the READY transition.
if (numberWorkersPending.decrementAndGet() == 0) {
message = String.format("All %s workers running, ready for queries...", workerConfs.size());
LOGGER.info(message);
launcher.sendMessageToClient(message.getBytes(StandardCharsets.UTF_8));
state = DriverState.READY;
}
}
}
/**
 * Reports a worker failure to the client launcher. A coordinator failure aborts the driver;
 * a worker failure during startup is logged with the number of workers still pending.
 *
 * @param workerId the ID of the failed worker
 */
private void updateDriverStateOnWorkerFailure(final int workerId) {
  launcher.sendMessageToClient(
      String.format("Worker %s failed", workerId).getBytes(StandardCharsets.UTF_8));
  if (workerId == MyriaConstants.MASTER_ID) {
    throw new RuntimeException("Shutting down driver on coordinator failure");
  }
  if (state == DriverState.PREPARING_WORKERS) {
    LOGGER.warn(
        "Worker failed in PREPARING_WORKERS phase, {} workers pending...",
        numberWorkersPending.get());
  }
}
/** Invoked once when the REEF driver starts; transitions to PREPARING_MASTER and launches the coordinator. */
final class StartHandler implements EventHandler<StartTime> {
  @Override
  public void onNext(final StartTime time) {
    LOGGER.info("Driver started at {}", time);
    Preconditions.checkState(state == DriverState.INIT);
    state = DriverState.PREPARING_MASTER;
    try {
      launchMaster();
    } catch (final InjectionException e) {
      // A misconfigured coordinator launch is unrecoverable: abort the driver.
      throw new RuntimeException(e);
    }
  }
}
/** Invoked when the driver is shutting down; closes the root context of every running task. */
final class StopHandler implements EventHandler<StopTime> {
  @Override
  public void onNext(final StopTime stopTime) {
    LOGGER.info("Driver stopped at {}", stopTime);
    tasksByWorkerId.values().forEach(task -> task.getActiveContext().close());
  }
}
/** Invoked when REEF allocates an evaluator; matches it to the next pending worker ID. */
final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> {
  @Override
  public void onNext(final AllocatedEvaluator evaluator) {
    final String nodeName = evaluator.getEvaluatorDescriptor().getNodeDescriptor().getName();
    LOGGER.info("Allocated evaluator {} on node {}", evaluator.getId(), nodeName);
    // Evaluator requests are matched to queued worker IDs in FIFO order.
    final Integer pendingWorkerId = workerIdsPendingEvaluatorAllocation.poll();
    Preconditions.checkState(pendingWorkerId != null, "No worker ID waiting for an evaluator!");
    scheduleTransition(pendingWorkerId, TaskStateEvent.EVALUATOR_ALLOCATED, evaluator);
  }
}
/** Evaluators are expected to run until failure or driver shutdown, never to complete normally. */
final class CompletedEvaluatorHandler implements EventHandler<CompletedEvaluator> {
  @Override
  public void onNext(final CompletedEvaluator evaluator) {
    throw new IllegalStateException("Unexpected CompletedEvaluator: " + evaluator.getId());
  }
}
/**
 * Invoked when an evaluator fails; extracts the worker ID from the failed root context and
 * schedules an EVALUATOR_FAILED transition for that worker's task state machine.
 */
final class EvaluatorFailureHandler implements EventHandler<FailedEvaluator> {
  @Override
  public void onNext(final FailedEvaluator failedEvaluator) {
    LOGGER.warn("FailedEvaluator: {}", failedEvaluator);
    // respawn evaluator and reschedule task if configured
    final List<FailedContext> failedContexts = failedEvaluator.getFailedContextList();
    if (failedContexts.isEmpty()) {
      // Without a failed context we have no way to recover the worker ID.
      throw new IllegalStateException(
          "Could not find worker ID for failed evaluator: " + failedEvaluator);
    }
    // we should have at most one context in the list (since we only allocate the root context)
    Preconditions.checkState(failedContexts.size() == 1);
    // Context IDs are worker IDs; parseInt avoids the box/unbox of Integer.valueOf.
    final int workerId = Integer.parseInt(failedContexts.get(0).getId());
    scheduleTransition(workerId, TaskStateEvent.EVALUATOR_FAILED, failedEvaluator);
  }
}
/**
 * Invoked when an evaluator's root context becomes active; schedules a CONTEXT_ALLOCATED
 * transition for the corresponding worker's task state machine.
 */
final class ActiveContextHandler implements EventHandler<ActiveContext> {
  @Override
  public void onNext(final ActiveContext context) {
    final String host = context.getEvaluatorDescriptor().getNodeDescriptor().getName();
    LOGGER.info("Context {} available on node {}", context.getId(), host);
    // Context IDs are worker IDs; parseInt avoids the box/unbox of Integer.valueOf.
    final int workerId = Integer.parseInt(context.getId());
    scheduleTransition(workerId, TaskStateEvent.CONTEXT_ALLOCATED, context);
  }
}
/**
 * Invoked when a context fails; schedules a CONTEXT_FAILED transition for the corresponding
 * worker's task state machine.
 */
final class ContextFailureHandler implements EventHandler<FailedContext> {
  @Override
  public void onNext(final FailedContext failedContext) {
    LOGGER.error("FailedContext: {}", failedContext);
    // Context IDs are worker IDs; parseInt avoids the box/unbox of Integer.valueOf.
    final int workerId = Integer.parseInt(failedContext.getId());
    scheduleTransition(workerId, TaskStateEvent.CONTEXT_FAILED, failedContext);
  }
}
/**
 * Invoked when a task starts running; schedules a TASK_RUNNING transition for the
 * corresponding worker's task state machine.
 */
final class RunningTaskHandler implements EventHandler<RunningTask> {
  @Override
  public void onNext(final RunningTask task) {
    LOGGER.info("Running task: {}", task.getId());
    // Task IDs are worker IDs; parseInt avoids the box/unbox of Integer.valueOf.
    final int workerId = Integer.parseInt(task.getId());
    scheduleTransition(workerId, TaskStateEvent.TASK_RUNNING, task);
  }
}
/** Tasks are expected to run until failure or driver shutdown, never to complete normally. */
final class CompletedTaskHandler implements EventHandler<CompletedTask> {
  @Override
  public void onNext(final CompletedTask completedTask) {
    throw new IllegalStateException("Unexpected CompletedTask: " + completedTask.getId());
  }
}
/**
 * Invoked when a task fails; schedules a TASK_FAILED transition for the corresponding
 * worker's task state machine.
 */
final class TaskFailureHandler implements EventHandler<FailedTask> {
  @Override
  public void onNext(final FailedTask failedTask) {
    LOGGER.warn(
        "FailedTask (ID {}): {}\n{}",
        failedTask.getId(),
        failedTask.getMessage(),
        failedTask.getReason());
    // Task IDs are worker IDs; parseInt avoids the box/unbox of Integer.valueOf.
    final int workerId = Integer.parseInt(failedTask.getId());
    scheduleTransition(workerId, TaskStateEvent.TASK_FAILED, failedTask);
  }
}
// NB: REEF only uses a single TaskMessage dispatch thread per Evaluator, so TaskMessage handlers
// should never block!
/**
 * Handles ADD_WORKER_ACK / REMOVE_WORKER_ACK control messages sent by tasks, dispatching each
 * ack either to the coordinator ack handlers or to the worker ack handlers depending on the
 * sender.
 */
final class TaskMessageHandler implements EventHandler<TaskMessage> {
  @Override
  public void onNext(final TaskMessage taskMessage) {
    // Task IDs are worker IDs; parseInt avoids the box/unbox of Integer.valueOf.
    final int senderId = Integer.parseInt(taskMessage.getMessageSourceID());
    final TransportMessage m;
    try {
      m = TransportMessage.parseFrom(taskMessage.get());
    } catch (final InvalidProtocolBufferException e) {
      LOGGER.warn("Could not parse TransportMessage from task message", e);
      return;
    }
    final ControlMessage controlM = m.getControlMessage();
    // We received a failed worker ack or recovered worker ack from the coordinator.
    final int workerId = controlM.getWorkerId();
    LOGGER.info(
        "Received {} for worker {} from worker {}", controlM.getType(), workerId, senderId);
    switch (controlM.getType()) {
      case REMOVE_WORKER_ACK:
        if (senderId == MyriaConstants.MASTER_ID) {
          onCoordinatorRemoveAck(workerId);
        } else {
          onWorkerRemoveAck(workerId, senderId);
        }
        break;
      case ADD_WORKER_ACK:
        if (senderId == MyriaConstants.MASTER_ID) {
          onCoordinatorAddAck(workerId);
        } else {
          onWorkerAddAck(workerId, senderId);
        }
        break;
      default:
        throw new IllegalStateException(
            "Expected control message to be ADD_WORKER_ACK or REMOVE_WORKER_ACK, got "
                + controlM.getType());
    }
  }
}
/** Responds to ping messages from the Myria launcher with a ping acknowledgment. */
final class ClientMessageHandler implements EventHandler<byte[]> {
  @Override
  public void onNext(final byte[] message) {
    final String received = new String(message, StandardCharsets.UTF_8);
    // The launcher only ever sends the ping message; anything else is a protocol error.
    Preconditions.checkArgument(MyriaDriver.DRIVER_PING_MSG.equals(received));
    LOGGER.info("Message from Myria launcher: {}", received);
    launcher.sendMessageToClient(
        MyriaDriver.DRIVER_PING_ACK_MSG.getBytes(StandardCharsets.UTF_8));
  }
}
/** Invoked when the client forcibly closes the driver; shuts down all allocated evaluators. */
final class ClientCloseHandler implements EventHandler<Void> {
  @Override
  public void onNext(final Void ignored) {
    LOGGER.info("Driver forcibly closed, shutting down evaluators...");
    evaluatorsByWorkerId.values().forEach(AllocatedEvaluator::close);
  }
}
}