/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.taskmanager; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; import org.junit.Before; import org.junit.Test; import java.net.URL; import java.util.Collections; import java.util.concurrent.Executor; import static org.junit.Assert.assertFalse; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class TaskAsyncCallTest { private static final int NUM_CALLS = 1000; private static OneShotLatch awaitLatch; private static OneShotLatch triggerLatch; @Before public void createQueuesAndActors() { awaitLatch = new OneShotLatch(); triggerLatch = new OneShotLatch(); } // ------------------------------------------------------------------------ // Tests // ------------------------------------------------------------------------ @Test public void testCheckpointCallsInOrder() { try { Task task = createTask(); task.startTaskThread(); awaitLatch.await(); for (int i = 1; i <= NUM_CALLS; i++) { task.triggerCheckpointBarrier(i, 156865867234L, CheckpointOptions.forFullCheckpoint()); } triggerLatch.await(); assertFalse(task.isCanceledOrFailed()); ExecutionState currentState = task.getExecutionState(); if (currentState != ExecutionState.RUNNING && currentState != ExecutionState.FINISHED) { fail("Task should be RUNNING or FINISHED, but is " + currentState); } task.cancelExecution(); task.getExecutingThread().join(); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testMixedAsyncCallsInOrder() { try { Task task = createTask(); task.startTaskThread(); awaitLatch.await(); for (int i = 1; i <= NUM_CALLS; i++) { task.triggerCheckpointBarrier(i, 156865867234L, CheckpointOptions.forFullCheckpoint()); task.notifyCheckpointComplete(i); } triggerLatch.await(); assertFalse(task.isCanceledOrFailed()); ExecutionState currentState = task.getExecutionState(); if (currentState != ExecutionState.RUNNING && currentState != ExecutionState.FINISHED) { fail("Task should be RUNNING or FINISHED, but is " + currentState); } task.cancelExecution(); task.getExecutingThread().join(); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } private static Task createTask() throws Exception { LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader()); ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); Executor executor = mock(Executor.class); NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); when(networkEnvironment.getResultPartitionManager()).thenReturn(partitionManager); when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); JobInformation jobInformation = new JobInformation( new JobID(), "Job Name", new SerializedValue<>(new ExecutionConfig()), new Configuration(), Collections.<BlobKey>emptyList(), Collections.<URL>emptyList()); TaskInformation taskInformation = new TaskInformation( new JobVertexID(), "Test Task", 1, 1, CheckpointsInOrderInvokable.class.getName(), new Configuration()); return new Task( jobInformation, taskInformation, new ExecutionAttemptID(), new AllocationID(), 0, 0, Collections.<ResultPartitionDeploymentDescriptor>emptyList(), Collections.<InputGateDeploymentDescriptor>emptyList(), 0, new TaskStateHandles(), mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, mock(BroadcastVariableManager.class), mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), libCache, mock(FileCache.class), new TestingTaskManagerRuntimeInfo(), taskMetricGroup, consumableNotifier, partitionProducerStateChecker, executor); } public static class CheckpointsInOrderInvokable extends AbstractInvokable implements StatefulTask { private volatile long lastCheckpointId = 0; private volatile Exception error; @Override public void invoke() throws Exception { awaitLatch.trigger(); // wait forever (until canceled) synchronized (this) { while (error == null && lastCheckpointId < NUM_CALLS) { wait(); } } triggerLatch.trigger(); if (error != null) { throw error; } } @Override public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {} @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) { lastCheckpointId++; if (checkpointMetaData.getCheckpointId() == lastCheckpointId) { if (lastCheckpointId == NUM_CALLS) { triggerLatch.trigger(); } } else if (this.error == null) { this.error = new Exception("calls out of order"); synchronized (this) { notifyAll(); } } return true; } @Override public void triggerCheckpointOnBarrier(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions, CheckpointMetrics checkpointMetrics) throws Exception { throw new UnsupportedOperationException("Should not be called"); } @Override public void abortCheckpointOnBarrier(long checkpointId, Throwable cause) { throw new UnsupportedOperationException("Should not be called"); } @Override public void notifyCheckpointComplete(long checkpointId) { if (checkpointId != lastCheckpointId && this.error == null) { this.error = new Exception("calls out of order"); synchronized (this) { notifyAll(); } } } } }