/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.taskmanager;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.blob.BlobKey;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.concurrent.Executors;
import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
import org.apache.flink.runtime.execution.CancelTaskException;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.JobInformation;
import org.apache.flink.runtime.executiongraph.TaskInformation;
import org.apache.flink.runtime.filecache.FileCache;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.messages.TaskManagerMessages;
import org.apache.flink.runtime.messages.TaskMessages;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
import org.apache.flink.runtime.query.TaskKvStateRegistry;
import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.WrappingRuntimeException;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import scala.concurrent.duration.FiniteDuration;
import javax.annotation.Nonnull;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Tests for the Task, which make sure that correct state transitions happen,
* and failures are correctly handled.
*
* All tests here have a set of mock actors for TaskManager, JobManager, and
* execution listener, which simply put the messages in a queue to be picked
* up by the test and validated.
*/
public class TaskTest extends TestLogger {
// Latches used to coordinate the test thread with the task thread. They are
// re-created before every test in createQueuesAndActors().
// NOTE(review): presumably static so that the test invokables (defined outside
// this view) can reach them without an instance reference — confirm.

// Tests block on this latch after starting the task thread, to wait until the
// invokable has reached the point of interest (e.g. "is in invoke").
private static OneShotLatch awaitLatch;
// Triggered by tests to let a waiting invokable proceed (e.g. to throw its exception).
private static OneShotLatch triggerLatch;
// Triggered during test cleanup (see testFatalErrorAfterUninterruptibleInvoke).
private static OneShotLatch cancelLatch;
// Gateways standing in for the TaskManager, JobManager and execution listener;
// each one is constructed around the corresponding message queue below.
private ActorGateway taskManagerGateway;
private ActorGateway jobManagerGateway;
private ActorGateway listenerGateway;
// Execution state listener built on top of listenerGateway.
private ActorGatewayTaskExecutionStateListener listener;
// TaskManager actions built on top of taskManagerGateway; handed to every created Task.
private ActorGatewayTaskManagerActions taskManagerConnection;
// Queues from which the validation methods take the messages to check.
private BlockingQueue<Object> taskManagerMessages;
private BlockingQueue<Object> jobManagerMessages;
private BlockingQueue<Object> listenerMessages;
/**
 * Creates fresh latches, message queues and the gateways / listener / TaskManager
 * connection built on top of those queues. Runs before every test and is also
 * called directly by tests that need to reset this state mid-test.
 */
@Before
public void createQueuesAndActors() {
    // latches shared with the task thread
    awaitLatch = new OneShotLatch();
    triggerLatch = new OneShotLatch();
    cancelLatch = new OneShotLatch();

    // TaskManager channel
    taskManagerMessages = new LinkedBlockingQueue<>();
    taskManagerGateway = new ForwardingActorGateway(taskManagerMessages);
    taskManagerConnection = new ActorGatewayTaskManagerActions(taskManagerGateway);

    // JobManager channel
    jobManagerMessages = new LinkedBlockingQueue<>();
    jobManagerGateway = new ForwardingActorGateway(jobManagerMessages);

    // execution listener channel
    listenerMessages = new LinkedBlockingQueue<>();
    listenerGateway = new ForwardingActorGateway(listenerMessages);
    listener = new ActorGatewayTaskExecutionStateListener(listenerGateway);
}
/**
 * Drops the references to the gateways and their message queues after every test.
 */
@After
public void clearActorsAndMessages() {
    // gateways
    taskManagerGateway = null;
    jobManagerGateway = null;
    listenerGateway = null;

    // message queues
    taskManagerMessages = null;
    jobManagerMessages = null;
    listenerMessages = null;
}
// ------------------------------------------------------------------------
// Tests
// ------------------------------------------------------------------------
/**
 * Runs a task to completion and checks the resulting state as well as the messages
 * sent to the listener and the TaskManager.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testRegularExecution() throws Exception {
    Task task = createTask(TestInvokableCorrect.class);

    // a freshly created task is in CREATED and has no failure recorded
    assertEquals(ExecutionState.CREATED, task.getExecutionState());
    assertFalse(task.isCanceledOrFailed());
    assertNull(task.getFailureCause());

    task.registerExecutionListener(listener);

    // run synchronously in the test thread. we should switch to DEPLOYING,
    // RUNNING, then FINISHED, and all should be good
    task.run();

    // verify final state
    assertEquals(ExecutionState.FINISHED, task.getExecutionState());
    assertFalse(task.isCanceledOrFailed());
    assertNull(task.getFailureCause());

    // verify listener messages
    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateListenerMessage(ExecutionState.FINISHED, task, false);

    // the TaskManager first receives the RUNNING state update and then the
    // message that the task reached its final state
    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());
}
/**
 * Cancels a task before it is run: the task must go to CANCELING immediately and
 * end up CANCELED once {@code run()} returns.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testCancelRightAway() throws Exception {
    Task task = createTask(TestInvokableCorrect.class);

    // cancel before the task ever ran
    task.cancelExecution();
    assertEquals(ExecutionState.CANCELING, task.getExecutionState());

    task.run();

    // verify final state
    assertEquals(ExecutionState.CANCELED, task.getExecutionState());
    validateUnregisterTask(task.getExecutionId());
}
/**
 * Fails a task externally before it is run: the task must be FAILED immediately and
 * must stay FAILED after {@code run()} returns.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testFailExternallyRightAway() throws Exception {
    Task task = createTask(TestInvokableCorrect.class);

    // fail before the task ever ran
    task.failExternally(new Exception("fail externally"));
    assertEquals(ExecutionState.FAILED, task.getExecutionState());

    task.run();

    // verify final state
    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    validateUnregisterTask(task.getExecutionId());
}
/**
 * Runs a task with a {@link LibraryCacheManager} mock that has no stubbed behavior
 * and expects the task to fail with a cause whose message mentions "classloader".
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testLibraryCacheRegistrationFailed() throws Exception {
    // unstubbed mock: it does not hand out a class loader
    Task task = createTask(TestInvokableCorrect.class, mock(LibraryCacheManager.class));

    // a freshly created task is in CREATED and has no failure recorded
    assertEquals(ExecutionState.CREATED, task.getExecutionState());
    assertFalse(task.isCanceledOrFailed());
    assertNull(task.getFailureCause());

    task.registerExecutionListener(listener);

    // should fail
    task.run();

    // verify final state
    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertNotNull(task.getFailureCause());
    assertTrue(task.getFailureCause().getMessage().contains("classloader"));

    // verify listener messages
    validateListenerMessage(ExecutionState.FAILED, task, true);

    // make sure that the TaskManager received a message that the task reached its final state
    validateUnregisterTask(task.getExecutionId());
}
/**
 * Runs a task against a {@link NetworkEnvironment} mock whose {@code registerTask}
 * throws, and expects the task to fail with that exception as the cause.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testExecutionFailsInNetworkRegistration() throws Exception {
    // mock a working library cache that hands out the test's own class loader
    LibraryCacheManager libCache = mock(LibraryCacheManager.class);
    when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());

    // mock a network manager that rejects registration
    ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
    ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
    PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class);
    Executor executor = mock(Executor.class);
    NetworkEnvironment network = mock(NetworkEnvironment.class);
    when(network.getResultPartitionManager()).thenReturn(partitionManager);
    when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
    doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class));

    Task task = createTask(TestInvokableCorrect.class, libCache, network, consumableNotifier, partitionProducerStateChecker, executor);
    task.registerExecutionListener(listener);

    task.run();

    // the registration failure must surface as the task's failure cause
    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertTrue(task.getFailureCause().getMessage().contains("buffers"));

    validateUnregisterTask(task.getExecutionId());
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Runs a task whose invokable class is {@code InvokableNonInstantiable} and expects
 * the task to fail with a cause whose message mentions "instantiate".
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testInvokableInstantiationFailed() throws Exception {
    Task task = createTask(InvokableNonInstantiable.class);
    task.registerExecutionListener(listener);

    task.run();

    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertTrue(task.getFailureCause().getMessage().contains("instantiate"));

    validateUnregisterTask(task.getExecutionId());
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Runs a task with {@code InvokableWithExceptionInInvoke} and expects the task to
 * reach RUNNING and then FAILED, with a failure cause whose message contains "test".
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testExecutionFailsInInvoke() throws Exception {
    Task task = createTask(InvokableWithExceptionInInvoke.class);
    task.registerExecutionListener(listener);

    task.run();

    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertTrue(task.getFailureCause().getMessage().contains("test"));

    // TaskManager messages are consumed in the order they were sent
    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Runs a task with {@code FailingInvokableWithChainedException} and expects the
 * reported failure cause to be an {@link IOException}.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testFailWithWrappedException() throws Exception {
    Task task = createTask(FailingInvokableWithChainedException.class);
    task.registerExecutionListener(listener);

    task.run();

    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());

    Throwable cause = task.getFailureCause();
    assertTrue(cause instanceof IOException);

    // TaskManager messages are consumed in the order they were sent
    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Cancels a task while its invokable is in invoke and expects the task to end up
 * CANCELED without a failure cause.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testCancelDuringInvoke() throws Exception {
    Task task = createTask(InvokableBlockingInInvoke.class);
    task.registerExecutionListener(listener);

    // run the task asynchronously
    task.startTaskThread();

    // wait till the task is in invoke
    awaitLatch.await();

    task.cancelExecution();

    // the task thread may or may not have completed the cancellation already
    assertTrue(task.getExecutionState() == ExecutionState.CANCELING ||
            task.getExecutionState() == ExecutionState.CANCELED);

    task.getExecutingThread().join();

    assertEquals(ExecutionState.CANCELED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertNull(task.getFailureCause());

    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateCancelingAndCanceledListenerMessage(task);
}
/**
 * Fails a task externally while its invokable is in invoke and expects the task to
 * be FAILED with the external exception as the failure cause.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testFailExternallyDuringInvoke() throws Exception {
    Task task = createTask(InvokableBlockingInInvoke.class);
    task.registerExecutionListener(listener);

    // run the task asynchronously
    task.startTaskThread();

    // wait till the task is in invoke
    awaitLatch.await();

    task.failExternally(new Exception("test"));
    // assertEquals reports the actual state on failure
    assertEquals(ExecutionState.FAILED, task.getExecutionState());

    task.getExecutingThread().join();

    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertTrue(task.getFailureCause().getMessage().contains("test"));

    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Cancels a task after it has already failed in invoke and checks that the
 * cancellation does not overwrite the FAILED state or the failure cause.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testCanceledAfterExecutionFailedInInvoke() throws Exception {
    Task task = createTask(InvokableWithExceptionInInvoke.class);
    task.registerExecutionListener(listener);

    task.run();

    // this should not overwrite the failure state
    task.cancelExecution();

    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertTrue(task.getFailureCause().getMessage().contains("test"));

    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Cancels a running task and only afterwards lets the invokable throw its exception.
 * The task must still end up CANCELED and must not record a failure cause.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testExecutionFailesAfterCanceling() throws Exception {
    Task task = createTask(InvokableWithExceptionOnTrigger.class);
    task.registerExecutionListener(listener);

    // run the task asynchronously
    task.startTaskThread();

    // wait till the task is in invoke
    awaitLatch.await();

    task.cancelExecution();
    assertEquals(ExecutionState.CANCELING, task.getExecutionState());

    // this causes an exception
    triggerLatch.trigger();

    task.getExecutingThread().join();

    // the late exception must not turn the cancellation into a failure
    assertEquals(ExecutionState.CANCELED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertNull(task.getFailureCause());

    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateCancelingAndCanceledListenerMessage(task);
}
/**
 * Fails a running task externally and only afterwards lets the invokable throw its
 * own exception. The task must stay FAILED and keep the external exception as the
 * failure cause.
 *
 * <p>Unexpected exceptions are propagated to JUnit so that the report contains the
 * complete stack trace instead of only the exception message.
 */
@Test
public void testExecutionFailsAfterTaskMarkedFailed() throws Exception {
    Task task = createTask(InvokableWithExceptionOnTrigger.class);
    task.registerExecutionListener(listener);

    // run the task asynchronously
    task.startTaskThread();

    // wait till the task is in invoke
    awaitLatch.await();

    task.failExternally(new Exception("external"));
    assertEquals(ExecutionState.FAILED, task.getExecutionState());

    // this causes an exception
    triggerLatch.trigger();

    task.getExecutingThread().join();

    // the external failure remains the reported cause
    assertEquals(ExecutionState.FAILED, task.getExecutionState());
    assertTrue(task.isCanceledOrFailed());
    assertTrue(task.getFailureCause().getMessage().contains("external"));

    validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
    validateUnregisterTask(task.getExecutionId());

    validateListenerMessage(ExecutionState.RUNNING, task, false);
    validateListenerMessage(ExecutionState.FAILED, task, true);
}
/**
 * Runs a task with {@code InvokableWithCancelTaskExceptionInInvoke} after the trigger
 * latch has already been fired and expects the task to end up CANCELED.
 */
@Test
public void testCancelTaskException() throws Exception {
    // Fire the latch up front; this is what makes the invokable raise the CancelTaskException.
    triggerLatch.trigger();

    Task cancelingTask = createTask(InvokableWithCancelTaskExceptionInInvoke.class);
    cancelingTask.run();

    ExecutionState finalState = cancelingTask.getExecutionState();
    assertEquals(ExecutionState.CANCELED, finalState);
}
/**
 * Fails a running task externally before the invokable gets the chance to raise its
 * CancelTaskException. The task must stay FAILED with the external exception as the
 * failure cause.
 */
@Test
public void testCancelTaskExceptionAfterTaskMarkedFailed() throws Exception {
    Task failedTask = createTask(InvokableWithCancelTaskExceptionInInvoke.class);
    failedTask.startTaskThread();

    // Block until the invokable signals that it is in invoke.
    awaitLatch.await();

    Exception externalFailure = new Exception("external");
    failedTask.failExternally(externalFailure);
    assertEquals(ExecutionState.FAILED, failedTask.getExecutionState());

    // Release the invokable. It ends either by raising the CancelTaskException
    // or by being interrupted (per the original note, by the TaskCanceler).
    triggerLatch.trigger();
    failedTask.getExecutingThread().join();

    assertEquals(ExecutionState.FAILED, failedTask.getExecutionState());
    assertTrue(failedTask.isCanceledOrFailed());

    String reportedMessage = failedTask.getFailureCause().getMessage();
    assertTrue(reportedMessage.contains("external"));
}
/**
 * Feeds every {@link ExecutionState} as producer state into
 * {@code Task.onPartitionStateUpdate} while the task is (forcibly) RUNNING and checks
 * the resulting task state, as well as the number of re-triggered partition requests.
 */
@Test
public void testOnPartitionStateUpdate() throws Exception {
    final IntermediateDataSetID consumedResult = new IntermediateDataSetID();
    final ResultPartitionID producerPartition = new ResultPartitionID();

    // input gate mock that claims to consume the result above
    final SingleInputGate gate = mock(SingleInputGate.class);
    when(gate.getConsumedResultId()).thenReturn(consumedResult);

    final Task task = createTask(InvokableBlockingInInvoke.class);
    setInputGate(task, gate);

    final ExecutionState[] producerStates = ExecutionState.values();

    // producer state -> task state expected after the update
    final Map<ExecutionState, ExecutionState> expectedTaskState = new HashMap<>(producerStates.length);

    // default: any producer state not listed below fails the task
    for (ExecutionState producerState : producerStates) {
        expectedTaskState.put(producerState, ExecutionState.FAILED);
    }

    // producer states for which the task keeps running
    expectedTaskState.put(ExecutionState.RUNNING, ExecutionState.RUNNING);
    expectedTaskState.put(ExecutionState.SCHEDULED, ExecutionState.RUNNING);
    expectedTaskState.put(ExecutionState.DEPLOYING, ExecutionState.RUNNING);
    expectedTaskState.put(ExecutionState.FINISHED, ExecutionState.RUNNING);

    // producer states for which the task is cancelled
    expectedTaskState.put(ExecutionState.CANCELED, ExecutionState.CANCELING);
    expectedTaskState.put(ExecutionState.CANCELING, ExecutionState.CANCELING);
    expectedTaskState.put(ExecutionState.FAILED, ExecutionState.CANCELING);

    for (ExecutionState producerState : producerStates) {
        // start every round from RUNNING, independent of the previous outcome
        setState(task, ExecutionState.RUNNING);
        task.onPartitionStateUpdate(consumedResult, producerPartition, producerState);
        assertEquals(expectedTaskState.get(producerState), task.getExecutionState());
    }

    // four re-triggers expected; this matches the four states mapped to RUNNING above
    verify(gate, times(4)).retriggerPartitionRequest(eq(producerPartition.getPartitionId()));
}
/**
* Tests how the task reacts to the different completions of the future returned by
* the partition producer state check. Four scenarios are covered, each with its own
* task: completion with a PartitionProducerDisposedException, with any other
* exception, with a TimeoutException, and successful completion.
*/
@Test
public void testTriggerPartitionStateUpdate() throws Exception {
IntermediateDataSetID resultId = new IntermediateDataSetID();
ResultPartitionID partitionId = new ResultPartitionID();
// Library cache mock that hands out the test's own class loader
LibraryCacheManager libCache = mock(LibraryCacheManager.class);
when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
// Stubbed per scenario below to return a future that the test completes itself
PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class);
ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
NetworkEnvironment network = mock(NetworkEnvironment.class);
when(network.getResultPartitionManager()).thenReturn(mock(ResultPartitionManager.class));
when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
.thenReturn(mock(TaskKvStateRegistry.class));
// NOTE(review): the task created here is discarded; every scenario below creates
// its own task. Looks like a leftover — confirm before removing.
createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor());
// Test all branches of trigger partition state check.
// NOTE(review): the assertions directly after completing the future presumably rely on
// Executors.directExecutor() running the callback in the completing thread — confirm.
{
// Reset latches (this also re-creates the message queues and gateways)
createQueuesAndActors();
// PartitionProducerDisposedException => task is expected to go to CANCELING.
// The task is not started in this scenario.
Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor());
FlinkCompletableFuture<ExecutionState> promise = new FlinkCompletableFuture<>();
when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);
task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId);
promise.completeExceptionally(new PartitionProducerDisposedException(partitionId));
assertEquals(ExecutionState.CANCELING, task.getExecutionState());
}
{
// Reset latches (this also re-creates the message queues and gateways)
createQueuesAndActors();
// Any other exception => task is expected to go to FAILED.
// The task is not started in this scenario.
Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor());
FlinkCompletableFuture<ExecutionState> promise = new FlinkCompletableFuture<>();
when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);
task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId);
promise.completeExceptionally(new RuntimeException("Any other exception"));
assertEquals(ExecutionState.FAILED, task.getExecutionState());
}
{
// Reset latches (this also re-creates the message queues and gateways)
createQueuesAndActors();
// TimeoutException is handled specially => the task keeps running and the
// partition request is re-triggered
Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor());
SingleInputGate inputGate = mock(SingleInputGate.class);
when(inputGate.getConsumedResultId()).thenReturn(resultId);
try {
task.startTaskThread();
// wait until the invokable is in invoke before injecting the mock gate
awaitLatch.await();
setInputGate(task, inputGate);
FlinkCompletableFuture<ExecutionState> promise = new FlinkCompletableFuture<>();
when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);
task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId);
promise.completeExceptionally(new TimeoutException());
assertEquals(ExecutionState.RUNNING, task.getExecutionState());
verify(inputGate, times(1)).retriggerPartitionRequest(eq(partitionId.getPartitionId()));
} finally {
// always stop the blocked task thread, even if an assertion failed
task.getExecutingThread().interrupt();
task.getExecutingThread().join();
}
}
{
// Reset latches (this also re-creates the message queues and gateways)
createQueuesAndActors();
// Success: producer reported as RUNNING => the task keeps running and the
// partition request is re-triggered
Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor());
SingleInputGate inputGate = mock(SingleInputGate.class);
when(inputGate.getConsumedResultId()).thenReturn(resultId);
try {
task.startTaskThread();
// wait until the invokable is in invoke before injecting the mock gate
awaitLatch.await();
setInputGate(task, inputGate);
FlinkCompletableFuture<ExecutionState> promise = new FlinkCompletableFuture<>();
when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);
task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId);
promise.complete(ExecutionState.RUNNING);
assertEquals(ExecutionState.RUNNING, task.getExecutionState());
verify(inputGate, times(1)).retriggerPartitionRequest(eq(partitionId.getPartitionId()));
} finally {
// always stop the blocked task thread, even if an assertion failed
task.getExecutingThread().interrupt();
task.getExecutingThread().join();
}
}
}
/**
 * Tests that the interrupt happens via the watch dog if the canceller is stuck in
 * cancel: task cancellation blocks the task canceller, and the task is interrupted
 * after cancel via the cancellation watch dog. No fatal error may be reported to
 * the TaskManager.
 *
 * <p>The cancellation options are set through their {@code ConfigOption} constants,
 * like in the other cancellation tests of this class, instead of raw key strings.
 */
@Test
public void testWatchDogInterruptsTask() throws Exception {
    Configuration config = new Configuration();
    config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5);
    // timeout far above the interval, so the cancellation timeout is not expected to fire in this test
    config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 60 * 1000);

    Task task = createTask(InvokableBlockingInCancel.class, config);
    task.startTaskThread();

    // wait until the invokable signals that it has reached the expected point
    awaitLatch.await();

    task.cancelExecution();
    task.getExecutingThread().join();

    // No fatal error
    for (Object msg : taskManagerMessages) {
        assertFalse("Unexpected FatalError message", msg instanceof TaskManagerMessages.FatalError);
    }
}
/**
 * invoke() holds a lock (and triggers awaitLatch once it has been acquired), while
 * cancel() tries to acquire the very same lock and therefore cannot complete. The
 * watch dog has to resolve this situation, and it must do so without a fatal error.
 */
@Test
public void testInterruptableSharedLockInInvokeAndCancel() throws Exception {
    final Configuration taskManagerConfig = new Configuration();
    taskManagerConfig.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5);
    taskManagerConfig.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50);

    final Task lockingTask = createTask(InvokableInterruptableSharedLockInInvokeAndCancel.class, taskManagerConfig);
    lockingTask.startTaskThread();

    // proceed only once the invokable has signalled the latch
    awaitLatch.await();

    lockingTask.cancelExecution();
    lockingTask.getExecutingThread().join();

    // none of the messages sent to the TaskManager may be a FatalError
    for (Object received : taskManagerMessages) {
        boolean isFatalError = received instanceof TaskManagerMessages.FatalError;
        assertFalse("Unexpected FatalError message", isFatalError);
    }
}
/**
 * invoke() blocks infinitely while cancel() does not block. This can only be
 * resolved by a fatal error, which the test expects to arrive as a TaskManager
 * message within ten polls of one second each.
 */
@Test
public void testFatalErrorAfterUninterruptibleInvoke() throws Exception {
    final Configuration taskManagerConfig = new Configuration();
    taskManagerConfig.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5);
    taskManagerConfig.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50);

    final Task stuckTask = createTask(InvokableUninterruptibleBlockingInvoke.class, taskManagerConfig);

    try {
        stuckTask.startTaskThread();
        awaitLatch.await();

        stuckTask.cancelExecution();

        // poll until a FatalError shows up or the attempts are used up
        boolean fatalErrorReceived = false;
        for (int attempt = 0; attempt < 10 && !fatalErrorReceived; attempt++) {
            Object received = taskManagerMessages.poll(1, TimeUnit.SECONDS);
            fatalErrorReceived = received instanceof TaskManagerMessages.FatalError;
        }

        if (!fatalErrorReceived) {
            fail("Did not receive expected task manager message");
        }
    } finally {
        // Release the cancel latch and interrupt once more so the task thread can terminate
        cancelLatch.trigger();
        stuckTask.getExecutingThread().interrupt();
        stuckTask.getExecutingThread().join();
    }
}
/**
 * Tests that the task configuration is respected and overwritten by the execution
 * config: before the task is started it reports the cancellation interval / timeout
 * from the TaskManager configuration, and once the invokable is in invoke it
 * reports the values from the {@link ExecutionConfig}.
 *
 * <p>The task thread is stopped in a {@code finally} block so that a failing
 * assertion does not leave a blocked task thread behind.
 */
@Test
public void testTaskConfig() throws Exception {
    long interval = 28218123;
    long timeout = interval + 19292;

    Configuration config = new Configuration();
    config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, interval);
    config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, timeout);

    // deliberately different from the TaskManager values, to tell the two sources apart
    ExecutionConfig executionConfig = new ExecutionConfig();
    executionConfig.setTaskCancellationInterval(interval + 1337);
    executionConfig.setTaskCancellationTimeout(timeout - 1337);

    Task task = createTask(InvokableBlockingInInvoke.class, config, executionConfig);

    // before the task runs, the TaskManager configuration applies
    assertEquals(interval, task.getTaskCancellationInterval());
    assertEquals(timeout, task.getTaskCancellationTimeout());

    try {
        task.startTaskThread();

        // wait till the task is in invoke
        awaitLatch.await();

        // by now the values from the execution config have taken over
        assertEquals(executionConfig.getTaskCancellationInterval(), task.getTaskCancellationInterval());
        assertEquals(executionConfig.getTaskCancellationTimeout(), task.getTaskCancellationTimeout());
    } finally {
        // always stop the blocked task thread, even if an assertion failed
        task.getExecutingThread().interrupt();
        task.getExecutingThread().join();
    }
}
// ------------------------------------------------------------------------
/**
 * Injects the given input gate into the task via reflection by overwriting the
 * task's {@code inputGates} and {@code inputGatesById} fields.
 *
 * @param task the task to modify
 * @param inputGate the gate to install as the task's only input gate
 * @throws RuntimeException if the fields cannot be accessed or written
 */
private void setInputGate(Task task, SingleInputGate inputGate) {
    try {
        // array of gates: exactly the one given gate
        final Field gatesField = Task.class.getDeclaredField("inputGates");
        gatesField.setAccessible(true);
        gatesField.set(task, new SingleInputGate[]{inputGate});

        // lookup by consumed result id: a single entry for the same gate
        final Map<IntermediateDataSetID, SingleInputGate> gatesByResult = new HashMap<>(1);
        gatesByResult.put(inputGate.getConsumedResultId(), inputGate);

        final Field gatesByIdField = Task.class.getDeclaredField("inputGatesById");
        gatesByIdField.setAccessible(true);
        gatesByIdField.set(task, gatesByResult);
    }
    catch (Exception reflectionFailure) {
        throw new RuntimeException("Modifying the task state failed", reflectionFailure);
    }
}
/**
 * Forces the task into the given execution state by overwriting its
 * {@code executionState} field via reflection.
 *
 * @param task the task to modify
 * @param state the state to set
 * @throws RuntimeException if the field cannot be accessed or written
 */
private void setState(Task task, ExecutionState state) {
    final Field stateField;
    try {
        stateField = Task.class.getDeclaredField("executionState");
        stateField.setAccessible(true);
        stateField.set(task, state);
    }
    catch (Exception reflectionFailure) {
        throw new RuntimeException("Modifying the task state failed", reflectionFailure);
    }
}
/**
 * Creates a task for the given invokable class with an empty TaskManager
 * configuration and a default {@link ExecutionConfig}.
 */
private Task createTask(Class<? extends AbstractInvokable> invokable) throws IOException {
    final Configuration emptyTaskManagerConfig = new Configuration();
    final ExecutionConfig defaultExecutionConfig = new ExecutionConfig();
    return createTask(invokable, emptyTaskManagerConfig, defaultExecutionConfig);
}
/**
 * Creates a task for the given invokable class and TaskManager configuration, using
 * a default {@link ExecutionConfig} and a library cache mock that returns the test's
 * class loader.
 */
private Task createTask(Class<? extends AbstractInvokable> invokable, Configuration config) throws IOException {
    // the three-argument overload performs exactly the same library cache setup
    return createTask(invokable, config, new ExecutionConfig());
}
/**
 * Creates a task for the given invokable class, TaskManager configuration and
 * execution config, using a library cache mock that returns the test's class loader
 * for any job.
 */
private Task createTask(Class<? extends AbstractInvokable> invokable, Configuration config, ExecutionConfig execConfig) throws IOException {
    final ClassLoader testClassLoader = getClass().getClassLoader();

    final LibraryCacheManager workingLibraryCache = mock(LibraryCacheManager.class);
    when(workingLibraryCache.getClassLoader(any(JobID.class))).thenReturn(testClassLoader);

    return createTask(invokable, workingLibraryCache, config, execConfig);
}
/**
 * Creates a task for the given invokable class and library cache manager with an
 * empty TaskManager configuration and a default {@link ExecutionConfig}.
 */
private Task createTask(
        Class<? extends AbstractInvokable> invokable,
        LibraryCacheManager libCache) throws IOException {
    final Configuration emptyTaskManagerConfig = new Configuration();
    final ExecutionConfig defaultExecutionConfig = new ExecutionConfig();
    return createTask(invokable, libCache, emptyTaskManagerConfig, defaultExecutionConfig);
}
/**
 * Creates a task for the given invokable class, library cache manager and
 * configurations. The network environment, consumable notifier, producer state
 * checker and executor are all mocks.
 */
private Task createTask(
        Class<? extends AbstractInvokable> invokable,
        LibraryCacheManager libCache,
        Configuration config,
        ExecutionConfig execConfig) throws IOException {

    // network environment mock with the stubs the task creation / startup relies on
    final NetworkEnvironment mockedNetwork = mock(NetworkEnvironment.class);
    when(mockedNetwork.getResultPartitionManager()).thenReturn(mock(ResultPartitionManager.class));
    when(mockedNetwork.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
    when(mockedNetwork.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
            .thenReturn(mock(TaskKvStateRegistry.class));

    // remaining collaborators are plain mocks without stubbed behavior
    final ResultPartitionConsumableNotifier mockedNotifier = mock(ResultPartitionConsumableNotifier.class);
    final PartitionProducerStateChecker mockedStateChecker = mock(PartitionProducerStateChecker.class);
    final Executor mockedExecutor = mock(Executor.class);

    return createTask(
            invokable,
            libCache,
            mockedNetwork,
            mockedNotifier,
            mockedStateChecker,
            mockedExecutor,
            config,
            execConfig);
}
/**
 * Creates a task with the given collaborators, an empty TaskManager configuration
 * and a default {@link ExecutionConfig}.
 */
private Task createTask(
        Class<? extends AbstractInvokable> invokable,
        LibraryCacheManager libCache,
        NetworkEnvironment networkEnvironment,
        ResultPartitionConsumableNotifier consumableNotifier,
        PartitionProducerStateChecker partitionProducerStateChecker,
        Executor executor) throws IOException {
    final Configuration emptyTaskManagerConfig = new Configuration();
    final ExecutionConfig defaultExecutionConfig = new ExecutionConfig();
    return createTask(
            invokable,
            libCache,
            networkEnvironment,
            consumableNotifier,
            partitionProducerStateChecker,
            executor,
            emptyTaskManagerConfig,
            defaultExecutionConfig);
}
/**
 * Creates the {@link Task} under test. Generates fresh job / vertex / attempt IDs,
 * wires the task to the test's JobManager gateway and TaskManager connection, and
 * fills every remaining dependency with a mock.
 *
 * @param invokable class of the invokable the task should run
 * @param libCache library cache manager handed to the task
 * @param networkEnvironment network environment handed to the task
 * @param consumableNotifier result partition consumable notifier handed to the task
 * @param partitionProducerStateChecker producer state checker handed to the task
 * @param executor executor handed to the task
 * @param taskManagerConfig configuration wrapped into the TaskManager runtime info
 * @param execConfig execution config, serialized into the job information
 * @return the newly constructed task
 * @throws IOException declared for the serialization of the execution config
 */
private Task createTask(
        Class<? extends AbstractInvokable> invokable,
        LibraryCacheManager libCache,
        NetworkEnvironment networkEnvironment,
        ResultPartitionConsumableNotifier consumableNotifier,
        PartitionProducerStateChecker partitionProducerStateChecker,
        Executor executor,
        Configuration taskManagerConfig,
        ExecutionConfig execConfig) throws IOException {

    // fresh identifiers for every task
    final JobID jid = new JobID();
    final JobVertexID vertexId = new JobVertexID();
    final ExecutionAttemptID attemptId = new ExecutionAttemptID();

    // both of these talk to the test's JobManager gateway
    final InputSplitProvider splitProvider = new TaskInputSplitProvider(
            jobManagerGateway,
            jid,
            vertexId,
            attemptId,
            new FiniteDuration(60, TimeUnit.SECONDS));
    final CheckpointResponder responder = new ActorGatewayCheckpointResponder(jobManagerGateway);

    // job-level and task-level descriptions
    final SerializedValue<ExecutionConfig> serializedExecConfig = new SerializedValue<>(execConfig);
    final JobInformation jobInfo = new JobInformation(
            jid,
            "Test Job",
            serializedExecConfig,
            new Configuration(),
            Collections.<BlobKey>emptyList(),
            Collections.<URL>emptyList());
    final TaskInformation taskInfo = new TaskInformation(
            vertexId,
            "Test Task",
            1,
            1,
            invokable.getName(),
            new Configuration());

    // metric group mock that also hands out a mocked I/O metric group
    final TaskMetricGroup metrics = mock(TaskMetricGroup.class);
    when(metrics.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class));

    final Task createdTask = new Task(
            jobInfo,
            taskInfo,
            attemptId,
            new AllocationID(),
            0,
            0,
            Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
            Collections.<InputGateDeploymentDescriptor>emptyList(),
            0,
            null,
            mock(MemoryManager.class),
            mock(IOManager.class),
            networkEnvironment,
            mock(BroadcastVariableManager.class),
            taskManagerConnection,
            splitProvider,
            responder,
            libCache,
            mock(FileCache.class),
            new TestingTaskManagerRuntimeInfo(taskManagerConfig),
            metrics,
            consumableNotifier,
            partitionProducerStateChecker,
            executor);
    return createdTask;
}
// ------------------------------------------------------------------------
// Validation Methods
// ------------------------------------------------------------------------
/**
 * Takes the next message from the TaskManager message queue and verifies that it is a
 * {@code TaskMessages.TaskInFinalState} notification for the given execution attempt.
 *
 * <p>The test fails if the message has a different type, refers to another execution
 * attempt, or if the thread is interrupted while waiting for the message.
 *
 * @param id execution attempt ID the notification must carry
 */
private void validateUnregisterTask(ExecutionAttemptID id) {
    try {
        // we may have to wait for a bit to give the actors time to receive the message
        // and put it into the queue
        Object rawMessage = taskManagerMessages.take();
        assertNotNull("There is no additional TaskManager message", rawMessage);
        if (!(rawMessage instanceof TaskMessages.TaskInFinalState)) {
            // report the type that is actually checked so a failure is not misleading
            fail("TaskManager message is not 'TaskInFinalState', but " + rawMessage.getClass());
        }
        TaskMessages.TaskInFinalState message = (TaskMessages.TaskInFinalState) rawMessage;
        assertEquals(id, message.executionID());
    }
    catch (InterruptedException e) {
        fail("interrupted");
    }
}
/**
 * Pulls the next message off the TaskManager message queue and checks that it is an
 * {@code UpdateTaskExecutionState} reporting the expected state for the given task.
 *
 * @param state execution state the message must report
 * @param task task whose job ID and execution ID the message must match
 * @param hasError {@code true} if the reported state must carry an error,
 *                 {@code false} if it must not
 */
private void validateTaskManagerStateChange(ExecutionState state, Task task, boolean hasError) {
    // the actors may need a moment to deliver the message, so wait for it here
    final Object received;
    try {
        received = taskManagerMessages.take();
    } catch (InterruptedException e) {
        fail("interrupted");
        return;
    }

    assertNotNull("There is no additional TaskManager message", received);
    if (!(received instanceof TaskMessages.UpdateTaskExecutionState)) {
        fail("TaskManager message is not 'UpdateTaskExecutionState', but " + received.getClass());
    }

    final TaskExecutionState reported =
            ((TaskMessages.UpdateTaskExecutionState) received).taskExecutionState();
    assertEquals(task.getJobID(), reported.getJobID());
    assertEquals(task.getExecutionId(), reported.getID());
    assertEquals(state, reported.getExecutionState());

    final Object error = reported.getError(getClass().getClassLoader());
    if (hasError) {
        assertNotNull(error);
    } else {
        assertNull(error);
    }
}
/**
 * Pulls the next message off the listener message queue and checks that it reports the
 * expected state for the given task.
 *
 * @param state execution state the message must report
 * @param task task whose job ID and execution ID the message must match
 * @param hasError {@code true} if the reported state must carry an error,
 *                 {@code false} if it must not
 */
private void validateListenerMessage(ExecutionState state, Task task, boolean hasError) {
    // the actors may need a moment to deliver the message, so wait for it here
    final TaskMessages.UpdateTaskExecutionState update;
    try {
        update = (TaskMessages.UpdateTaskExecutionState) listenerMessages.take();
    } catch (InterruptedException e) {
        fail("interrupted");
        return;
    }

    assertNotNull("There is no additional listener message", update);

    final TaskExecutionState reported = update.taskExecutionState();
    assertEquals(task.getJobID(), reported.getJobID());
    assertEquals(task.getExecutionId(), reported.getID());
    assertEquals(state, reported.getExecutionState());

    final Object error = reported.getError(getClass().getClassLoader());
    if (hasError) {
        assertNotNull(error);
    } else {
        assertNull(error);
    }
}
/**
 * Takes the next two listener messages and checks that, between them, they report the
 * {@code CANCELING} and {@code CANCELED} states of the given task, in either order.
 *
 * @param task task whose job ID and execution ID both messages must match
 */
private void validateCancelingAndCanceledListenerMessage(Task task) {
    // the actors may need a moment to deliver the messages, so wait for both here
    final TaskMessages.UpdateTaskExecutionState first;
    final TaskMessages.UpdateTaskExecutionState second;
    try {
        first = (TaskMessages.UpdateTaskExecutionState) listenerMessages.take();
        second = (TaskMessages.UpdateTaskExecutionState) listenerMessages.take();
    } catch (InterruptedException e) {
        fail("interrupted");
        return;
    }

    assertNotNull("There is no additional listener message", first);
    assertNotNull("There is no additional listener message", second);

    final TaskExecutionState firstReport = first.taskExecutionState();
    final TaskExecutionState secondReport = second.taskExecutionState();

    assertEquals(task.getJobID(), firstReport.getJobID());
    assertEquals(task.getJobID(), secondReport.getJobID());
    assertEquals(task.getExecutionId(), firstReport.getID());
    assertEquals(task.getExecutionId(), secondReport.getID());

    final ExecutionState firstState = firstReport.getExecutionState();
    final ExecutionState secondState = secondReport.getExecutionState();

    // The two notifications may (very rarely) arrive swapped:
    //  - OUTSIDE THREAD: call to cancel()
    //  - OUTSIDE THREAD: atomic state change from running to canceling
    //  - TASK THREAD: finishes, atomic change from canceling to canceled
    //  - TASK THREAD: send notification that state is canceled
    //  - OUTSIDE THREAD: send notification that state is canceling
    // Hence both orderings are accepted.
    final boolean cancelingThenCanceled =
            firstState == ExecutionState.CANCELING && secondState == ExecutionState.CANCELED;
    final boolean canceledThenCanceling =
            secondState == ExecutionState.CANCELING && firstState == ExecutionState.CANCELED;
    assertTrue(cancelingThenCanceled || canceledThenCanceling);
}
// --------------------------------------------------------------------------------------------
// Mock invokable code
// --------------------------------------------------------------------------------------------
/**
 * Invokable whose {@code invoke()} returns immediately and whose {@code cancel()} fails
 * the test if it is ever called.
 */
public static final class TestInvokableCorrect extends AbstractInvokable {

    @Override
    public void invoke() {
        // intentionally empty: the task simply finishes
    }

    @Override
    public void cancel() throws Exception {
        fail("This should not be called");
    }
}
/** Invokable whose {@code invoke()} fails right away with a plain {@link Exception}. */
public static final class InvokableWithExceptionInInvoke extends AbstractInvokable {

    @Override
    public void invoke() throws Exception {
        throw new Exception("test");
    }
}
/**
 * Invokable that triggers {@code awaitLatch}, then waits for {@code triggerLatch} while
 * ignoring interrupts, and finally fails with a {@link RuntimeException}.
 */
public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable {

    @Override
    public void invoke() {
        awaitLatch.trigger();

        // an interrupt must not release us early: keep waiting until the latch
        // has actually been awaited successfully
        boolean released = false;
        while (!released) {
            try {
                triggerLatch.await();
                released = true;
            } catch (InterruptedException ignored) {
                // interrupted before the latch fired; go back to waiting
            }
        }

        throw new RuntimeException("test");
    }
}
/** Abstract invokable; being abstract, it cannot be instantiated. */
public abstract static class InvokableNonInstantiable extends AbstractInvokable {
}
/**
 * Invokable that triggers {@code awaitLatch} and then blocks in {@code invoke()} until the
 * executing thread is interrupted.
 */
public static final class InvokableBlockingInInvoke extends AbstractInvokable {

    /**
     * Signals {@code awaitLatch} and then waits on this object's monitor indefinitely.
     *
     * @throws Exception the {@link InterruptedException} raised by {@code wait()} when the
     *                   executing thread is interrupted
     */
    @Override
    public void invoke() throws Exception {
        awaitLatch.trigger();
        // block forever
        synchronized (this) {
            // Object.wait() may return without notify/interrupt (spurious wakeup);
            // waiting in a loop ensures only an interrupt ends the blocking
            while (true) {
                wait();
            }
        }
    }
}
/**
 * Invokable that triggers {@code awaitLatch}, waits for {@code triggerLatch} and then
 * throws a {@link CancelTaskException} from {@code invoke()}.
 */
public static final class InvokableWithCancelTaskExceptionInInvoke extends AbstractInvokable {

    @Override
    public void invoke() throws Exception {
        awaitLatch.trigger();

        try {
            triggerLatch.await();
        } catch (Throwable ignored) {
            // deliberately ignored: the CancelTaskException below is thrown regardless
            // of how the wait ended
        }

        throw new CancelTaskException();
    }
}
/**
 * Invokable whose {@code invoke()} and {@code cancel()} both synchronize on the same
 * private lock object.
 *
 * <p>{@code invoke()} triggers {@code awaitLatch} while holding the lock and then calls
 * {@code wait()}; {@code cancel()} triggers {@code cancelLatch} while holding the lock.
 */
public static final class InvokableInterruptableSharedLockInInvokeAndCancel extends AbstractInvokable {
// monitor shared by invoke() and cancel()
private final Object lock = new Object();
@Override
public void invoke() throws Exception {
synchronized (lock) {
awaitLatch.trigger();
// NOTE(review): wait() is invoked on this invokable, but this block only holds the
// monitor of 'lock'. Unless the calling thread also owns this object's monitor,
// Object.wait() throws IllegalMonitorStateException instead of blocking.
// TODO confirm which blocking call was intended here (lock.wait() would release 'lock').
wait();
}
}
@Override
public void cancel() throws Exception {
// needs the same monitor that invoke() holds while it runs its synchronized block
synchronized (lock) {
cancelLatch.trigger();
}
}
}
/**
 * Invokable whose {@code cancel()} blocks by waiting on this object's monitor.
 *
 * <p>{@code invoke()} triggers {@code awaitLatch}, waits for {@code cancelLatch} and then
 * waits on this object's monitor. If it is interrupted at either point, it calls
 * {@code notifyAll()} on this object so that a thread waiting in {@code cancel()} is woken up.
 */
public static final class InvokableBlockingInCancel extends AbstractInvokable {
@Override
public void invoke() throws Exception {
awaitLatch.trigger();
try {
// wait until cancel() has been entered (it triggers cancelLatch) ...
cancelLatch.await();
// ... then block on this object's monitor
synchronized (this) {
wait();
}
} catch (InterruptedException ignored) {
// interrupted while waiting: release the thread(s) blocked in cancel()
synchronized (this) {
notifyAll(); // notify all that are stuck in cancel
}
}
}
@Override
public void cancel() throws Exception {
synchronized (this) {
// signal invoke() that cancel() has been called, then block until notified
cancelLatch.trigger();
wait();
}
}
}
/**
 * Invokable whose {@code invoke()} keeps blocking until {@code cancelLatch} has been
 * triggered, ignoring interrupts; its {@code cancel()} does nothing.
 */
public static final class InvokableUninterruptibleBlockingInvoke extends AbstractInvokable {
@Override
public void invoke() throws Exception {
// only a triggered cancelLatch ends the loop; an interrupt merely starts the next iteration
while (!cancelLatch.isTriggered()) {
try {
synchronized (this) {
// signal that invoke() is running, then block on this object's monitor
awaitLatch.trigger();
wait();
}
} catch (InterruptedException ignored) {
// interrupts are deliberately swallowed; the loop condition decides whether to go on
}
}
}
@Override
public void cancel() throws Exception {
// intentionally empty: this method does not trigger cancelLatch or wake up invoke()
}
}
/**
 * Invokable whose {@code invoke()} fails with a {@code TestWrappedException} that wraps
 * an {@link IOException}; {@code cancel()} is a no-op.
 */
public static final class FailingInvokableWithChainedException extends AbstractInvokable {

    @Override
    public void invoke() throws Exception {
        final IOException cause = new IOException("test");
        throw new TestWrappedException(cause);
    }

    @Override
    public void cancel() {
        // nothing to cancel
    }
}
// ------------------------------------------------------------------------
// test exceptions
// ------------------------------------------------------------------------
/**
 * {@code WrappingRuntimeException} subclass used by the tests; it hands the given cause
 * to the superclass constructor.
 */
private static class TestWrappedException extends WrappingRuntimeException {

    private static final long serialVersionUID = 1L;

    /**
     * @param cause throwable passed on to {@code WrappingRuntimeException}
     */
    public TestWrappedException(@Nonnull Throwable cause) {
        super(cause);
    }
}
}