/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.checkpoint; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; 
import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testutils.CommonTestUtils; import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; import java.util.UUID; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static 
org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.Mockito.withSettings; /** * Tests for the checkpoint coordinator. */ public class CheckpointCoordinatorTest { @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); @Test public void testCheckpointAbortsIfTriggerTasksAreNotExecuted() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages ExecutionVertex triggerVertex1 = mock(ExecutionVertex.class); ExecutionVertex triggerVertex2 = mock(ExecutionVertex.class); // create some mock Execution vertices that need to ack the checkpoint final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID(); ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1); ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex1, triggerVertex2 }, new ExecutionVertex[] { ackVertex1, ackVertex2 }, new ExecutionVertex[] {}, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); // nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, 
coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. this should not succeed assertFalse(coord.triggerCheckpoint(timestamp, false)); // still, nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testCheckpointAbortsIfTriggerTasksAreFinished() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID(); ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1); JobVertexID jobVertexID2 = new JobVertexID(); ExecutionVertex triggerVertex2 = mockExecutionVertex( triggerAttemptID2, jobVertexID2, Lists.newArrayList(OperatorID.fromJobVertexID(jobVertexID2)), 1, 1, ExecutionState.FINISHED); // create some mock Execution vertices that need to ack the checkpoint final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID(); ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1); ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex1, triggerVertex2 }, new ExecutionVertex[] { ackVertex1, ackVertex2 }, new ExecutionVertex[] {}, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); // nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, 
coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. this should not succeed assertFalse(coord.triggerCheckpoint(timestamp, false)); // still, nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testCheckpointAbortsIfAckTasksAreNotExecuted() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that need to ack the checkpoint final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID(); ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1); ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2); // create some mock Execution vertices that receive the checkpoint trigger messages ExecutionVertex ackVertex1 = mock(ExecutionVertex.class); ExecutionVertex ackVertex2 = mock(ExecutionVertex.class); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex1, triggerVertex2 }, new ExecutionVertex[] { ackVertex1, ackVertex2 }, new ExecutionVertex[] {}, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); // nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. 
this should not succeed assertFalse(coord.triggerCheckpoint(timestamp, false)); // still, nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * This test triggers a checkpoint and then sends a decline checkpoint message from * one of the tasks. The expected behaviour is that said checkpoint is discarded and a new * checkpoint is triggered. */ @Test public void testTriggerAndDeclineCheckpointSimple() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID attemptID2 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); ExecutionVertex vertex2 = mockExecutionVertex(attemptID2); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. 
this should succeed assertTrue(coord.triggerCheckpoint(timestamp, false)); // validate that we have a pending checkpoint assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // we have one task scheduled that will cancel after timeout assertEquals(1, coord.getNumScheduledTasks()); long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId); assertNotNull(checkpoint); assertEquals(checkpointId, checkpoint.getCheckpointId()); assertEquals(timestamp, checkpoint.getCheckpointTimestamp()); assertEquals(jid, checkpoint.getJobId()); assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(0, checkpoint.getOperatorStates().size()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); // check that the vertices received the trigger checkpoint message verify(vertex1.getCurrentExecutionAttempt()).triggerCheckpoint(checkpointId, timestamp, CheckpointOptions.forFullCheckpoint()); verify(vertex2.getCurrentExecutionAttempt()).triggerCheckpoint(checkpointId, timestamp, CheckpointOptions.forFullCheckpoint()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); // acknowledge from one of the tasks coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId)); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId)); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); // decline checkpoint from the other task, this 
should cancel the checkpoint // and trigger a new one coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpointId)); assertTrue(checkpoint.isDiscarded()); // the canceler is also removed assertEquals(0, coord.getNumScheduledTasks()); // validate that we have no new pending checkpoint assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // decline again, nothing should happen // decline from the other task, nothing should happen coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpointId)); coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID2, checkpointId)); assertTrue(checkpoint.isDiscarded()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * This test triggers two checkpoints and then sends a decline message from one of the tasks * for the first checkpoint. This should discard the first checkpoint while not triggering * a new checkpoint because a later checkpoint is already in progress. 
*/ @Test public void testTriggerAndDeclineCheckpointComplex() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID attemptID2 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); ExecutionVertex vertex2 = mockExecutionVertex(attemptID2); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(0, coord.getNumScheduledTasks()); // trigger the first checkpoint. 
this should succeed assertTrue(coord.triggerCheckpoint(timestamp, false)); // trigger second checkpoint, should also succeed assertTrue(coord.triggerCheckpoint(timestamp + 2, false)); // validate that we have a pending checkpoint assertEquals(2, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(2, coord.getNumScheduledTasks()); Iterator<Map.Entry<Long, PendingCheckpoint>> it = coord.getPendingCheckpoints().entrySet().iterator(); long checkpoint1Id = it.next().getKey(); long checkpoint2Id = it.next().getKey(); PendingCheckpoint checkpoint1 = coord.getPendingCheckpoints().get(checkpoint1Id); PendingCheckpoint checkpoint2 = coord.getPendingCheckpoints().get(checkpoint2Id); assertNotNull(checkpoint1); assertEquals(checkpoint1Id, checkpoint1.getCheckpointId()); assertEquals(timestamp, checkpoint1.getCheckpointTimestamp()); assertEquals(jid, checkpoint1.getJobId()); assertEquals(2, checkpoint1.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint1.getNumberOfAcknowledgedTasks()); assertEquals(0, checkpoint1.getOperatorStates().size()); assertFalse(checkpoint1.isDiscarded()); assertFalse(checkpoint1.isFullyAcknowledged()); assertNotNull(checkpoint2); assertEquals(checkpoint2Id, checkpoint2.getCheckpointId()); assertEquals(timestamp + 2, checkpoint2.getCheckpointTimestamp()); assertEquals(jid, checkpoint2.getJobId()); assertEquals(2, checkpoint2.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint2.getNumberOfAcknowledgedTasks()); assertEquals(0, checkpoint2.getOperatorStates().size()); assertFalse(checkpoint2.isDiscarded()); assertFalse(checkpoint2.isFullyAcknowledged()); // check that the vertices received the trigger checkpoint message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpoint1Id), eq(timestamp), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpoint1Id), eq(timestamp), 
any(CheckpointOptions.class)); } // check that the vertices received the trigger checkpoint message for the second checkpoint { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpoint2Id), eq(timestamp + 2), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpoint2Id), eq(timestamp + 2), any(CheckpointOptions.class)); } // decline checkpoint from one of the tasks, this should cancel the checkpoint coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpoint1Id)); assertTrue(checkpoint1.isDiscarded()); // validate that we have only one pending checkpoint left assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(1, coord.getNumScheduledTasks()); // validate that it is the same second checkpoint from earlier long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); PendingCheckpoint checkpointNew = coord.getPendingCheckpoints().get(checkpointIdNew); assertEquals(checkpoint2Id, checkpointIdNew); assertNotNull(checkpointNew); assertEquals(checkpointIdNew, checkpointNew.getCheckpointId()); assertEquals(jid, checkpointNew.getJobId()); assertEquals(2, checkpointNew.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpointNew.getNumberOfAcknowledgedTasks()); assertEquals(0, checkpointNew.getOperatorStates().size()); assertFalse(checkpointNew.isDiscarded()); assertFalse(checkpointNew.isFullyAcknowledged()); assertNotEquals(checkpoint1.getCheckpointId(), checkpointNew.getCheckpointId()); // decline again, nothing should happen // decline from the other task, nothing should happen coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpoint1Id)); coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID2, checkpoint1Id)); assertTrue(checkpoint1.isDiscarded()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { 
e.printStackTrace(); fail(e.getMessage()); } } @Test public void testTriggerAndConfirmSimpleCheckpoint() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID attemptID2 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); ExecutionVertex vertex2 = mockExecutionVertex(attemptID2); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(0, coord.getNumScheduledTasks()); // trigger the first checkpoint. 
this should succeed assertTrue(coord.triggerCheckpoint(timestamp, false)); // validate that we have a pending checkpoint assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(1, coord.getNumScheduledTasks()); long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId); assertNotNull(checkpoint); assertEquals(checkpointId, checkpoint.getCheckpointId()); assertEquals(timestamp, checkpoint.getCheckpointTimestamp()); assertEquals(jid, checkpoint.getJobId()); assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(0, checkpoint.getOperatorStates().size()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates(); operatorStates.put(opID1, new SpyInjectingOperatorState( opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); operatorStates.put(opID2, new SpyInjectingOperatorState( opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism())); // check that the vertices received the trigger checkpoint message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); } // acknowledge from one of the tasks AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); 
coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the other task. coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed assertTrue(checkpoint.isDiscarded()); // the now we should have a completed checkpoint assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); // the canceler should be removed now assertEquals(0, coord.getNumScheduledTasks()); // validate that the subtasks states have registered their shared states. 
{ verify(subtaskState1, times(1)).registerSharedStates(any(SharedStateRegistry.class)); verify(subtaskState2, times(1)).registerSharedStates(any(SharedStateRegistry.class)); } // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); } CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0); assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); assertEquals(checkpoint.getCheckpointId(), success.getCheckpointID()); assertEquals(2, success.getOperatorStates().size()); // --------------- // trigger another checkpoint and see that this one replaces the other checkpoint // --------------- final long timestampNew = timestamp + 7; coord.triggerCheckpoint(timestampNew, false); long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(0, coord.getNumScheduledTasks()); CompletedCheckpoint successNew = coord.getSuccessfulCheckpoints().get(0); assertEquals(jid, successNew.getJobId()); assertEquals(timestampNew, successNew.getTimestamp()); assertEquals(checkpointIdNew, successNew.getCheckpointID()); assertTrue(successNew.getOperatorStates().isEmpty()); // validate that the relevant tasks got a confirmation message { 
verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); verify(vertex1.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointIdNew), eq(timestampNew)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointIdNew), eq(timestampNew)); } coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testMultipleConcurrentCheckpoints() { try { final JobID jid = new JobID(); final long timestamp1 = System.currentTimeMillis(); final long timestamp2 = timestamp1 + 8617; // create some mock execution vertices final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID3 = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1); ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2); ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1); ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2); ExecutionVertex ackVertex3 = mockExecutionVertex(ackAttemptID3); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex1, triggerVertex2 }, new ExecutionVertex[] { ackVertex1, ackVertex2, ackVertex3 }, 
new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp1, false)); assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); PendingCheckpoint pending1 = coord.getPendingCheckpoints().values().iterator().next(); long checkpointId1 = pending1.getCheckpointId(); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); // acknowledge one of the three tasks coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1)); // start the second checkpoint // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp2, false)); assertEquals(2, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); PendingCheckpoint pending2; { Iterator<PendingCheckpoint> all = coord.getPendingCheckpoints().values().iterator(); PendingCheckpoint cc1 = all.next(); PendingCheckpoint cc2 = all.next(); pending2 = pending1 == cc1 ? 
cc2 : cc1; } long checkpointId2 = pending2.getCheckpointId(); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); // we acknowledge the remaining two tasks from the first // checkpoint and two tasks from the second checkpoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2)); // now, the first checkpoint should be confirmed assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertTrue(pending1.isDiscarded()); // the first confirm message should be out verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId1), eq(timestamp1)); // send the last remaining ack for the second checkpoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2)); // now, the second checkpoint should be confirmed assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(2, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertTrue(pending2.isDiscarded()); // the second commit message should be out verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); // validate the committed checkpoints List<CompletedCheckpoint> scs = coord.getSuccessfulCheckpoints(); CompletedCheckpoint sc1 = scs.get(0); assertEquals(checkpointId1, sc1.getCheckpointID()); 
assertEquals(timestamp1, sc1.getTimestamp()); assertEquals(jid, sc1.getJobId()); assertTrue(sc1.getOperatorStates().isEmpty()); CompletedCheckpoint sc2 = scs.get(1); assertEquals(checkpointId2, sc2.getCheckpointID()); assertEquals(timestamp2, sc2.getTimestamp()); assertEquals(jid, sc2.getJobId()); assertTrue(sc2.getOperatorStates().isEmpty()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testSuccessfulCheckpointSubsumesUnsuccessful() { try { final JobID jid = new JobID(); final long timestamp1 = System.currentTimeMillis(); final long timestamp2 = timestamp1 + 1552; // create some mock execution vertices final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID3 = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1); ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2); ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1); ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2); ExecutionVertex ackVertex3 = mockExecutionVertex(ackAttemptID3); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex1, triggerVertex2 }, new ExecutionVertex[] { ackVertex1, ackVertex2, ackVertex3 }, new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(10), null, Executors.directExecutor()); 
assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp1, false)); assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); PendingCheckpoint pending1 = coord.getPendingCheckpoints().values().iterator().next(); long checkpointId1 = pending1.getCheckpointId(); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId()); Map<OperatorID, OperatorState> operatorStates1 = pending1.getOperatorStates(); operatorStates1.put(opID1, new SpyInjectingOperatorState( opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); operatorStates1.put(opID2, new SpyInjectingOperatorState( opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); operatorStates1.put(opID3, new SpyInjectingOperatorState( opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); // acknowledge one of the three tasks coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); // start the second checkpoint // trigger the first checkpoint. 
this should succeed assertTrue(coord.triggerCheckpoint(timestamp2, false)); assertEquals(2, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); PendingCheckpoint pending2; { Iterator<PendingCheckpoint> all = coord.getPendingCheckpoints().values().iterator(); PendingCheckpoint cc1 = all.next(); PendingCheckpoint cc2 = all.next(); pending2 = pending1 == cc1 ? cc2 : cc1; } long checkpointId2 = pending2.getCheckpointId(); Map<OperatorID, OperatorState> operatorStates2 = pending2.getOperatorStates(); operatorStates2.put(opID1, new SpyInjectingOperatorState( opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); operatorStates2.put(opID2, new SpyInjectingOperatorState( opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); operatorStates2.put(opID3, new SpyInjectingOperatorState( opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); // we acknowledge one more task from the first checkpoint and the second // checkpoint completely. 
The second checkpoint should then subsume the first checkpoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); // now, the second checkpoint should be confirmed, and the first discarded // actually both pending checkpoints are discarded, and the second has been transformed // into a successful checkpoint assertTrue(pending1.isDiscarded()); assertTrue(pending2.isDiscarded()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that all received subtask states in the first checkpoint have been discarded verify(subtaskState1_1, times(1)).discardState(); verify(subtaskState1_2, times(1)).discardState(); // validate that all subtask states in the second checkpoint are not discarded verify(subtaskState2_1, never()).discardState(); verify(subtaskState2_2, never()).discardState(); verify(subtaskState2_3, never()).discardState(); // validate the committed checkpoints List<CompletedCheckpoint> scs = 
coord.getSuccessfulCheckpoints(); CompletedCheckpoint success = scs.get(0); assertEquals(checkpointId2, success.getCheckpointID()); assertEquals(timestamp2, success.getTimestamp()); assertEquals(jid, success.getJobId()); assertEquals(3, success.getOperatorStates().size()); // the first confirm message should be out verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); // send the last remaining ack for the first checkpoint. This should not do anything SubtaskState subtaskState1_3 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3)); verify(subtaskState1_3, times(1)).discardState(); coord.shutdown(JobStatus.FINISHED); // validate that the states in the second checkpoint have been discarded verify(subtaskState2_1, times(1)).discardState(); verify(subtaskState2_2, times(1)).discardState(); verify(subtaskState2_3, times(1)).discardState(); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testCheckpointTimeoutIsolated() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock execution vertices final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID); ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1); ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); // set up the coordinator // the timeout for the checkpoint is a 200 milliseconds CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 200, 0, Integer.MAX_VALUE, 
ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex1, ackVertex2 }, new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); // trigger a checkpoint, partially acknowledged assertTrue(coord.triggerCheckpoint(timestamp, false)); assertEquals(1, coord.getNumberOfPendingCheckpoints()); PendingCheckpoint checkpoint = coord.getPendingCheckpoints().values().iterator().next(); assertFalse(checkpoint.isDiscarded()); OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates(); operatorStates.put(opID1, new SpyInjectingOperatorState( opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); // wait until the checkpoint must have expired. 
// we check every 250 msecs conservatively for 5 seconds // to give even slow build servers a very good chance of completing this long deadline = System.currentTimeMillis() + 5000; do { Thread.sleep(250); } while (!checkpoint.isDiscarded() && coord.getNumberOfPendingCheckpoints() > 0 && System.currentTimeMillis() < deadline); assertTrue("Checkpoint was not canceled by the timeout", checkpoint.isDiscarded()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that the received states have been discarded verify(subtaskState, times(1)).discardState(); // no confirm message must have been sent verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testHandleMessagesForNonExistingCheckpoints() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock execution vertices and trigger some checkpoint final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID); ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1); ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); CheckpointCoordinator coord = new CheckpointCoordinator( jid, 200000, 200000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex1, ackVertex2 }, new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new 
StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); assertTrue(coord.triggerCheckpoint(timestamp, false)); long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next(); // send some messages that do not belong to either the job or the any // of the vertices that need to be acknowledged. // non of the messages should throw an exception // wrong job id coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), ackAttemptID1, checkpointId)); // unknown checkpoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, 1L)); // unknown ack vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, new ExecutionAttemptID(), checkpointId)); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * Tests that late acknowledge checkpoint messages are properly cleaned up. Furthermore it tests * that unknown checkpoint messages for the same job a are cleaned up as well. In contrast * checkpointing messages from other jobs should not be touched. A late acknowledge * message is an acknowledge message which arrives after the checkpoint has been declined. 
* * @throws Exception */ @Test public void testStateCleanupForLateOrUnknownMessages() throws Exception { final JobID jobId = new JobID(); final ExecutionAttemptID triggerAttemptId = new ExecutionAttemptID(); final ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptId); final ExecutionAttemptID ackAttemptId1 = new ExecutionAttemptID(); final ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptId1); final ExecutionAttemptID ackAttemptId2 = new ExecutionAttemptID(); final ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptId2); final long timestamp = 1L; CheckpointCoordinator coord = new CheckpointCoordinator( jobId, 20000L, 20000L, 0L, 1, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] {triggerVertex, ackVertex1, ackVertex2}, new ExecutionVertex[0], new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); assertTrue(coord.triggerCheckpoint(timestamp, false)); assertEquals(1, coord.getNumberOfPendingCheckpoints()); PendingCheckpoint pendingCheckpoint = coord.getPendingCheckpoints().values().iterator().next(); long checkpointId = pendingCheckpoint.getCheckpointId(); OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId()); OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); Map<OperatorID, OperatorState> operatorStates = pendingCheckpoint.getOperatorStates(); operatorStates.put(opIDtrigger, new SpyInjectingOperatorState( opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism())); operatorStates.put(opID1, new SpyInjectingOperatorState( opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); operatorStates.put(opID2, new SpyInjectingOperatorState( opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); 
// acknowledge the first trigger vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex()); // verify that the subtask state has not been discarded verify(storedTriggerSubtaskState, never()).discardState(); SubtaskState unknownSubtaskState = mock(SubtaskState.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState)); // we should discard acknowledge messages from an unknown vertex belonging to our job verify(unknownSubtaskState, times(1)).discardState(); SubtaskState differentJobSubtaskState = mock(SubtaskState.class); // receive an acknowledge message from an unknown job coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); // we should not interfere with different jobs verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the trigger vertex SubtaskState triggerSubtaskState = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex reset(storedTriggerSubtaskState); coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId)); assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state verify(storedTriggerSubtaskState, times(1)).discardState(); SubtaskState 
ackSubtaskState = mock(SubtaskState.class); // late acknowledge message from the second ack vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState)); // check that we also cleaned up this state verify(ackSubtaskState, times(1)).discardState(); // receive an acknowledge message from an unknown job reset(differentJobSubtaskState); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); // we should not interfere with different jobs verify(differentJobSubtaskState, never()).discardState(); SubtaskState unknownSubtaskState2 = mock(SubtaskState.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2)); // we should discard acknowledge messages from an unknown vertex belonging to our job verify(unknownSubtaskState2, times(1)).discardState(); } @Test public void testPeriodicTriggering() { try { final JobID jid = new JobID(); final long start = System.currentTimeMillis(); // create some mock execution vertices and trigger some checkpoint final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID); ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); final AtomicInteger numCalls = new AtomicInteger(); final Execution execution = triggerVertex.getCurrentExecutionAttempt(); doAnswer(new Answer<Void>() { private long lastId = -1; private long lastTs = -1; @Override public Void answer(InvocationOnMock invocation) throws Throwable { long id = 
(Long) invocation.getArguments()[0]; long ts = (Long) invocation.getArguments()[1]; assertTrue(id > lastId); assertTrue(ts >= lastTs); assertTrue(ts >= start); lastId = id; lastTs = ts; numCalls.incrementAndGet(); return null; } }).when(execution).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class)); CheckpointCoordinator coord = new CheckpointCoordinator( jid, 10, // periodic interval is 10 ms 200000, // timeout is very long (200 s) 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex }, new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); coord.startCheckpointScheduler(); long timeout = System.currentTimeMillis() + 60000; do { Thread.sleep(20); } while (timeout > System.currentTimeMillis() && numCalls.get() < 5); assertTrue(numCalls.get() >= 5); coord.stopCheckpointScheduler(); // for 400 ms, no further calls may come. 
// there may be the case that one trigger was fired and about to // acquire the lock, such that after cancelling it will still do // the remainder of its work int numCallsSoFar = numCalls.get(); Thread.sleep(400); assertTrue(numCallsSoFar == numCalls.get() || numCallsSoFar+1 == numCalls.get()); // start another sequence of periodic scheduling numCalls.set(0); coord.startCheckpointScheduler(); timeout = System.currentTimeMillis() + 60000; do { Thread.sleep(20); } while (timeout > System.currentTimeMillis() && numCalls.get() < 5); assertTrue(numCalls.get() >= 5); coord.stopCheckpointScheduler(); // for 400 ms, no further calls may come // there may be the case that one trigger was fired and about to // acquire the lock, such that after cancelling it will still do // the remainder of its work numCallsSoFar = numCalls.get(); Thread.sleep(400); assertTrue(numCallsSoFar == numCalls.get() || numCallsSoFar + 1 == numCalls.get()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * This test verified that after a completed checkpoint a certain time has passed before * another is triggered. 
*/ @Test public void testMinTimeBetweenCheckpointsInterval() throws Exception { final JobID jid = new JobID(); // create some mock execution vertices and trigger some checkpoint final ExecutionAttemptID attemptID = new ExecutionAttemptID(); final ExecutionVertex vertex = mockExecutionVertex(attemptID); final Execution executionAttempt = vertex.getCurrentExecutionAttempt(); final BlockingQueue<Long> triggerCalls = new LinkedBlockingQueue<>(); doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { triggerCalls.add((Long) invocation.getArguments()[0]); return null; } }).when(executionAttempt).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class)); final long delay = 50; final CheckpointCoordinator coord = new CheckpointCoordinator( jid, 2, // periodic interval is 2 ms 200_000, // timeout is very long (200 s) delay, // 50 ms delay between checkpoints 1, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex }, new ExecutionVertex[] { vertex }, new ExecutionVertex[] { vertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), "dummy-path", Executors.directExecutor()); try { coord.startCheckpointScheduler(); // wait until the first checkpoint was triggered Long firstCallId = triggerCalls.take(); assertEquals(1L, firstCallId.longValue()); AcknowledgeCheckpoint ackMsg = new AcknowledgeCheckpoint(jid, attemptID, 1L); // tell the coordinator that the checkpoint is done final long ackTime = System.nanoTime(); coord.receiveAcknowledgeMessage(ackMsg); // wait until the next checkpoint is triggered Long nextCallId = triggerCalls.take(); final long nextCheckpointTime = System.nanoTime(); assertEquals(2L, nextCallId.longValue()); final long delayMillis = (nextCheckpointTime - ackTime) / 1_000_000; // we need to add one ms here to account for rounding errors if (delayMillis + 1 < delay) { fail("checkpoint came too early: delay was " + delayMillis + " but should have been 
at least " + delay); } } finally { coord.stopCheckpointScheduler(); coord.shutdown(JobStatus.FINISHED); } } @Test public void testMaxConcurrentAttempts1() { testMaxConcurrentAttempts(1); } @Test public void testMaxConcurrentAttempts2() { testMaxConcurrentAttempts(2); } @Test public void testMaxConcurrentAttempts5() { testMaxConcurrentAttempts(5); } @Test public void testTriggerAndConfirmSimpleSavepoint() throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID attemptID2 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); ExecutionVertex vertex2 = mockExecutionVertex(attemptID2); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // trigger the first checkpoint. 
this should succeed String savepointDir = tmpFolder.newFolder().getAbsolutePath(); Future<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir); assertFalse(savepointFuture.isDone()); // validate that we have a pending savepoint assertEquals(1, coord.getNumberOfPendingCheckpoints()); long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId); assertNotNull(pending); assertEquals(checkpointId, pending.getCheckpointId()); assertEquals(timestamp, pending.getCheckpointTimestamp()); assertEquals(jid, pending.getJobId()); assertEquals(2, pending.getNumberOfNonAcknowledgedTasks()); assertEquals(0, pending.getNumberOfAcknowledgedTasks()); assertEquals(0, pending.getOperatorStates().size()); assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(pending.canBeSubsumed()); OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates(); operatorStates.put(opID1, new SpyInjectingOperatorState( opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); operatorStates.put(opID2, new SpyInjectingOperatorState( opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); // acknowledge from one of the tasks AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, pending.getNumberOfAcknowledgedTasks()); assertEquals(1, pending.getNumberOfNonAcknowledgedTasks()); assertFalse(pending.isDiscarded()); 
assertFalse(pending.isFullyAcknowledged()); assertFalse(savepointFuture.isDone()); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(savepointFuture.isDone()); // acknowledge the other task. coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed assertTrue(pending.isDiscarded()); assertTrue(savepointFuture.isDone()); // the now we should have a completed checkpoint assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId), eq(timestamp)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId), eq(timestamp)); } // validate that the shared states are registered { verify(subtaskState1, times(1)).registerSharedStates(any(SharedStateRegistry.class)); verify(subtaskState2, times(1)).registerSharedStates(any(SharedStateRegistry.class)); } CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0); assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); assertEquals(pending.getCheckpointId(), success.getCheckpointID()); assertEquals(2, success.getOperatorStates().size()); // --------------- // trigger another checkpoint and see that this one replaces the other checkpoint // --------------- final long timestampNew = timestamp + 7; savepointFuture = coord.triggerSavepoint(timestampNew, savepointDir); 
assertFalse(savepointFuture.isDone()); long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); CompletedCheckpoint successNew = coord.getSuccessfulCheckpoints().get(0); assertEquals(jid, successNew.getJobId()); assertEquals(timestampNew, successNew.getTimestamp()); assertEquals(checkpointIdNew, successNew.getCheckpointID()); assertTrue(successNew.getOperatorStates().isEmpty()); assertTrue(savepointFuture.isDone()); // validate that the first savepoint does not discard its private states. verify(subtaskState1, never()).discardState(); verify(subtaskState2, never()).discardState(); // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); verify(vertex1.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointIdNew), eq(timestampNew)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointIdNew), eq(timestampNew)); } coord.shutdown(JobStatus.FINISHED); } /** * Triggers a savepoint and two checkpoints. The second checkpoint completes * and subsumes the first checkpoint, but not the first savepoint. Then we * trigger another checkpoint and savepoint. 
The 2nd savepoint completes and * subsumes the last checkpoint, but not the first savepoint. */ @Test public void testSavepointsAreNotSubsumed() throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); // create some mock Execution vertices that receive the checkpoint trigger messages final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID attemptID2 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); ExecutionVertex vertex2 = mockExecutionVertex(attemptID2); StandaloneCheckpointIDCounter counter = new StandaloneCheckpointIDCounter(); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, new ExecutionVertex[] { vertex1, vertex2 }, counter, new StandaloneCompletedCheckpointStore(10), null, Executors.directExecutor()); String savepointDir = tmpFolder.newFolder().getAbsolutePath(); // Trigger savepoint and checkpoint Future<CompletedCheckpoint> savepointFuture1 = coord.triggerSavepoint(timestamp, savepointDir); long savepointId1 = counter.getLast(); assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertTrue(coord.triggerCheckpoint(timestamp + 1, false)); assertEquals(2, coord.getNumberOfPendingCheckpoints()); assertTrue(coord.triggerCheckpoint(timestamp + 2, false)); long checkpointId2 = counter.getLast(); assertEquals(3, coord.getNumberOfPendingCheckpoints()); // 2nd checkpoint should subsume the 1st checkpoint, but not the savepoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId2)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId2)); assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(1, 
coord.getNumberOfRetainedSuccessfulCheckpoints()); assertFalse(coord.getPendingCheckpoints().get(savepointId1).isDiscarded()); assertFalse(savepointFuture1.isDone()); assertTrue(coord.triggerCheckpoint(timestamp + 3, false)); assertEquals(2, coord.getNumberOfPendingCheckpoints()); Future<CompletedCheckpoint> savepointFuture2 = coord.triggerSavepoint(timestamp + 4, savepointDir); long savepointId2 = counter.getLast(); assertEquals(3, coord.getNumberOfPendingCheckpoints()); // 2nd savepoint should subsume the last checkpoint, but not the 1st savepoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, savepointId2)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, savepointId2)); assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertEquals(2, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertFalse(coord.getPendingCheckpoints().get(savepointId1).isDiscarded()); assertFalse(savepointFuture1.isDone()); assertTrue(savepointFuture2.isDone()); // Ack first savepoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, savepointId1)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, savepointId1)); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(3, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertTrue(savepointFuture1.isDone()); } private void testMaxConcurrentAttempts(int maxConcurrentAttempts) { try { final JobID jid = new JobID(); // create some mock execution vertices and trigger some checkpoint final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID); ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); final AtomicInteger numCalls = new 
AtomicInteger(); final Execution execution = triggerVertex.getCurrentExecutionAttempt(); doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { numCalls.incrementAndGet(); return null; } }).when(execution).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class)); doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { numCalls.incrementAndGet(); return null; } }).when(execution).notifyCheckpointComplete(anyLong(), anyLong()); CheckpointCoordinator coord = new CheckpointCoordinator( jid, 10, // periodic interval is 10 ms 200000, // timeout is very long (200 s) 0L, // no extra delay maxConcurrentAttempts, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex }, new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); coord.startCheckpointScheduler(); // after a while, there should be exactly as many checkpoints // as concurrently permitted long now = System.currentTimeMillis(); long timeout = now + 60000; long minDuration = now + 100; do { Thread.sleep(20); } while ((now = System.currentTimeMillis()) < minDuration || (numCalls.get() < maxConcurrentAttempts && now < timeout)); assertEquals(maxConcurrentAttempts, numCalls.get()); verify(triggerVertex.getCurrentExecutionAttempt(), times(maxConcurrentAttempts)) .triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class)); // now, once we acknowledge one checkpoint, it should trigger the next one coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID, 1L)); // this should have immediately triggered a new checkpoint now = System.currentTimeMillis(); timeout = now + 60000; do { Thread.sleep(20); } while (numCalls.get() < maxConcurrentAttempts + 1 && now < timeout); assertEquals(maxConcurrentAttempts + 1, numCalls.get()); // 
no further checkpoints should happen Thread.sleep(200); assertEquals(maxConcurrentAttempts + 1, numCalls.get()); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testMaxConcurrentAttempsWithSubsumption() { try { final int maxConcurrentAttempts = 2; final JobID jid = new JobID(); // create some mock execution vertices and trigger some checkpoint final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID); ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); CheckpointCoordinator coord = new CheckpointCoordinator( jid, 10, // periodic interval is 10 ms 200000, // timeout is very long (200 s) 0L, // no extra delay maxConcurrentAttempts, // max two concurrent checkpoints ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex }, new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); coord.startCheckpointScheduler(); // after a while, there should be exactly as many checkpoints // as concurrently permitted long now = System.currentTimeMillis(); long timeout = now + 60000; long minDuration = now + 100; do { Thread.sleep(20); } while ((now = System.currentTimeMillis()) < minDuration || (coord.getNumberOfPendingCheckpoints() < maxConcurrentAttempts && now < timeout)); // validate that the pending checkpoints are there assertEquals(maxConcurrentAttempts, coord.getNumberOfPendingCheckpoints()); assertNotNull(coord.getPendingCheckpoints().get(1L)); assertNotNull(coord.getPendingCheckpoints().get(2L)); // now we acknowledge the second checkpoint, which should 
subsume the first checkpoint // and allow two more checkpoints to be triggered // now, once we acknowledge one checkpoint, it should trigger the next one coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID, 2L)); // after a while, there should be the new checkpoints final long newTimeout = System.currentTimeMillis() + 60000; do { Thread.sleep(20); } while (coord.getPendingCheckpoints().get(4L) == null && System.currentTimeMillis() < newTimeout); // do the final check assertEquals(maxConcurrentAttempts, coord.getNumberOfPendingCheckpoints()); assertNotNull(coord.getPendingCheckpoints().get(3L)); assertNotNull(coord.getPendingCheckpoints().get(4L)); coord.shutdown(JobStatus.FINISHED); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @Test public void testPeriodicSchedulingWithInactiveTasks() { try { final JobID jid = new JobID(); // create some mock execution vertices and trigger some checkpoint final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID(); final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID(); ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID); ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID); ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID); final AtomicReference<ExecutionState> currentState = new AtomicReference<>(ExecutionState.CREATED); when(triggerVertex.getCurrentExecutionAttempt().getState()).thenAnswer( new Answer<ExecutionState>() { @Override public ExecutionState answer(InvocationOnMock invocation){ return currentState.get(); } }); CheckpointCoordinator coord = new CheckpointCoordinator( jid, 10, // periodic interval is 10 ms 200000, // timeout is very long (200 s) 0L, // no extra delay 2, // max two concurrent checkpoints ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex }, new ExecutionVertex[] 
{ commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); coord.startCheckpointScheduler(); // no checkpoint should have started so far Thread.sleep(200); assertEquals(0, coord.getNumberOfPendingCheckpoints()); // now move the state to RUNNING currentState.set(ExecutionState.RUNNING); // the coordinator should start checkpointing now final long timeout = System.currentTimeMillis() + 10000; do { Thread.sleep(20); } while (System.currentTimeMillis() < timeout && coord.getNumberOfPendingCheckpoints() == 0); assertTrue(coord.getNumberOfPendingCheckpoints() > 0); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * Tests that the savepoints can be triggered concurrently. */ @Test public void testConcurrentSavepoints() throws Exception { JobID jobId = new JobID(); final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); StandaloneCheckpointIDCounter checkpointIDCounter = new StandaloneCheckpointIDCounter(); CheckpointCoordinator coord = new CheckpointCoordinator( jobId, 100000, 200000, 0L, 1, // max one checkpoint at a time => should not affect savepoints ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1 }, new ExecutionVertex[] { vertex1 }, new ExecutionVertex[] { vertex1 }, checkpointIDCounter, new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); List<Future<CompletedCheckpoint>> savepointFutures = new ArrayList<>(); int numSavepoints = 5; String savepointDir = tmpFolder.newFolder().getAbsolutePath(); // Trigger savepoints for (int i = 0; i < numSavepoints; i++) { savepointFutures.add(coord.triggerSavepoint(i, savepointDir)); } // After triggering multiple savepoints, all should in progress for (Future<CompletedCheckpoint> savepointFuture : savepointFutures) { assertFalse(savepointFuture.isDone()); } // ACK all savepoints long checkpointId = 
checkpointIDCounter.getLast(); for (int i = 0; i < numSavepoints; i++, checkpointId--) { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, attemptID1, checkpointId)); } // After ACKs, all should be completed for (Future<CompletedCheckpoint> savepointFuture : savepointFutures) { assertTrue(savepointFuture.isDone()); } } /** * Tests that no minimum delay between savepoints is enforced. */ @Test public void testMinDelayBetweenSavepoints() throws Exception { JobID jobId = new JobID(); final ExecutionAttemptID attemptID1 = new ExecutionAttemptID(); ExecutionVertex vertex1 = mockExecutionVertex(attemptID1); CheckpointCoordinator coord = new CheckpointCoordinator( jobId, 100000, 200000, 100000000L, // very long min delay => should not affect savepoints 1, ExternalizedCheckpointSettings.none(), new ExecutionVertex[] { vertex1 }, new ExecutionVertex[] { vertex1 }, new ExecutionVertex[] { vertex1 }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, Executors.directExecutor()); String savepointDir = tmpFolder.newFolder().getAbsolutePath(); Future<CompletedCheckpoint> savepoint0 = coord.triggerSavepoint(0, savepointDir); assertFalse("Did not trigger savepoint", savepoint0.isDone()); Future<CompletedCheckpoint> savepoint1 = coord.triggerSavepoint(1, savepointDir); assertFalse("Did not trigger savepoint", savepoint1.isDone()); } /** * Tests that the checkpointed partitioned and non-partitioned state is assigned properly to * the {@link Execution} upon recovery. 
* * @throws Exception */ @Test public void testRestoreLatestCheckpointedState() throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); final JobVertexID jobVertexID1 = new JobVertexID(); final JobVertexID jobVertexID2 = new JobVertexID(); int parallelism1 = 3; int parallelism2 = 2; int maxParallelism1 = 42; int maxParallelism2 = 13; final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, maxParallelism1); final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( jobVertexID2, parallelism2, maxParallelism2); List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore(); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), arrayExecutionVertices, arrayExecutionVertices, arrayExecutionVertices, new StandaloneCheckpointIDCounter(), store, null, Executors.directExecutor()); // trigger the checkpoint coord.triggerCheckpoint(timestamp, false); assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId); OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1); OperatorID 
opID2 = OperatorID.fromJobVertexID(jobVertexID2); Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates(); operatorStates.put(opID1, new SpyInjectingOperatorState( opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism())); operatorStates.put(opID2, new SpyInjectingOperatorState( opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism())); for (int index = 0; index < jobVertex1.getParallelism(); index++) { SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskState); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskState); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints(); assertEquals(1, completedCheckpoints.size()); // shutdown the store store.shutdown(JobStatus.SUSPENDED); // restore the store Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>(); tasks.put(jobVertexID1, jobVertex1); tasks.put(jobVertexID2, jobVertex2); coord.restoreLatestCheckpointedState(tasks, true, false); // validate that all shared states are registered again after the recovery. 
for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) { for (OperatorSubtaskState subtaskState : taskState.getStates()) { verify(subtaskState, times(2)).registerSharedStates(any(SharedStateRegistry.class)); } } } // verify the restored state verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1); verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2); } /** * Tests that the checkpoint restoration fails if the max parallelism of the job vertices has * changed. * * @throws Exception */ @Test(expected=IllegalStateException.class) public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); final JobVertexID jobVertexID1 = new JobVertexID(); final JobVertexID jobVertexID2 = new JobVertexID(); int parallelism1 = 3; int parallelism2 = 2; int maxParallelism1 = 42; int maxParallelism2 = 13; final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, maxParallelism1); final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( jobVertexID2, parallelism2, maxParallelism2); List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), arrayExecutionVertices, arrayExecutionVertices, arrayExecutionVertices, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); // trigger 
the checkpoint coord.triggerCheckpoint(timestamp, false); assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints(); assertEquals(1, 
completedCheckpoints.size()); Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>(); int newMaxParallelism1 = 20; int newMaxParallelism2 = 42; final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, newMaxParallelism1); final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( jobVertexID2, parallelism2, newMaxParallelism2); tasks.put(jobVertexID1, newJobVertex1); tasks.put(jobVertexID2, newJobVertex2); coord.restoreLatestCheckpointedState(tasks, true, false); fail("The restoration should have failed because the max parallelism changed."); } /** * Tests that the checkpoint restoration fails if the parallelism of a job vertices with * non-partitioned state has changed. * * @throws Exception */ @Test(expected=IllegalStateException.class) public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); final JobVertexID jobVertexID1 = new JobVertexID(); final JobVertexID jobVertexID2 = new JobVertexID(); int parallelism1 = 3; int parallelism2 = 2; int maxParallelism1 = 42; int maxParallelism2 = 13; final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, maxParallelism1); final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( jobVertexID2, parallelism2, maxParallelism2); List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), arrayExecutionVertices, arrayExecutionVertices, 
arrayExecutionVertices, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); // trigger the checkpoint coord.triggerCheckpoint(timestamp, false); assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState( jobVertexID1, keyGroupPartitions1.get(index), false); SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState( jobVertexID2, keyGroupPartitions2.get(index), false); SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles); 
coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints(); assertEquals(1, completedCheckpoints.size()); Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>(); int newParallelism1 = 4; int newParallelism2 = 3; final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( jobVertexID1, newParallelism1, maxParallelism1); final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( jobVertexID2, newParallelism2, maxParallelism2); tasks.put(jobVertexID1, newJobVertex1); tasks.put(jobVertexID2, newJobVertex2); coord.restoreLatestCheckpointedState(tasks, true, false); fail("The restoration should have failed because the parallelism of an vertex with " + "non-partitioned state changed."); } @Test public void testRestoreLatestCheckpointedStateScaleIn() throws Exception { testRestoreLatestCheckpointedStateWithChangingParallelism(false); } @Test public void testRestoreLatestCheckpointedStateScaleOut() throws Exception { testRestoreLatestCheckpointedStateWithChangingParallelism(false); } @Test public void testStateRecoveryWhenTopologyChangeOut() throws Exception { testStateRecoveryWithTopologyChange(0); } @Test public void testStateRecoveryWhenTopologyChangeIn() throws Exception { testStateRecoveryWithTopologyChange(1); } @Test public void testStateRecoveryWhenTopologyChange() throws Exception { testStateRecoveryWithTopologyChange(2); } /** * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned * state. * * @throws Exception */ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut) throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); final JobVertexID jobVertexID1 = new JobVertexID(); final JobVertexID jobVertexID2 = new JobVertexID(); int parallelism1 = 3; int parallelism2 = scaleOut ? 
2 : 13; int maxParallelism1 = 42; int maxParallelism2 = 13; int newParallelism2 = scaleOut ? 13 : 2; final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, maxParallelism1); final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( jobVertexID2, parallelism2, maxParallelism2); List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, 600000, 600000, 0, Integer.MAX_VALUE, ExternalizedCheckpointSettings.none(), arrayExecutionVertices, arrayExecutionVertices, arrayExecutionVertices, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, Executors.directExecutor()); // trigger the checkpoint coord.triggerCheckpoint(timestamp, false); assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); //vertex 1 for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, 
keyGroupPartitions1.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true); SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } //vertex 2 final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesBackend = new ArrayList<>(jobVertex2.getParallelism()); final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesRaw = new ArrayList<>(jobVertex2.getParallelism()); for (int index = 0; index < jobVertex2.getParallelism(); index++) { KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true); ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); ChainedStateHandle<OperatorStateHandle> opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true); expectedOpStatesBackend.add(opStateBackend); expectedOpStatesRaw.add(opStateRaw); SubtaskState checkpointStateHandles = new SubtaskState(new ChainedStateHandle<>( Collections.<StreamStateHandle>singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } List<CompletedCheckpoint> completedCheckpoints = 
coord.getSuccessfulCheckpoints(); assertEquals(1, completedCheckpoints.size()); Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>(); List<KeyGroupRange> newKeyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2); final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, maxParallelism1); // rescale vertex 2 final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( jobVertexID2, newParallelism2, maxParallelism2); tasks.put(jobVertexID1, newJobVertex1); tasks.put(jobVertexID2, newJobVertex2); coord.restoreLatestCheckpointedState(tasks, true, false); // verify the restored state verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1); List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism()); List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState(); List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState(); List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState(); Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState(); Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState(); actualOpStatesBackend.add(opStateBackend); actualOpStatesRaw.add(opStateRaw); // the 'non partition state' is not null because 
it is recombined. assertNotNull(operatorState); for (int index = 0; index < operatorState.getLength(); index++) { assertNull(operatorState.get(index)); } compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); } comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend); comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw); }

	/**
	 * Creates a fresh {@link JobVertexID} together with the {@link OperatorID} derived from it via
	 * {@link OperatorID#fromJobVertexID(JobVertexID)}.
	 *
	 * @return a pair whose {@code f0} is the new job vertex ID and whose {@code f1} is the derived operator ID
	 */
	private static Tuple2<JobVertexID, OperatorID> generateIDPair() {
		JobVertexID jobVertexID = new JobVertexID();
		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
		return new Tuple2<>(jobVertexID, operatorID);
	}

	/**
	 * Restores a checkpoint written by an old job topology into a changed topology and checks where each
	 * operator's state ends up.
	 *
	 * <pre>
	 * old topology
	 * [operator1,operator2] * parallelism1 -> [operator3,operator4] * parallelism2
	 *
	 * new topology
	 * [operator5,operator1,operator2] * newParallelism1 -> [operator3, operator6] * newParallelism2
	 * </pre>
	 *
	 * <p>Expectations verified below:
	 * <ul>
	 *   <li>operator1 and operator2 get back their legacy, managed and raw operator state; the new operator5
	 *       gets none; the first vertex receives no keyed state.</li>
	 *   <li>operator3 gets back its managed/raw operator state (compared in aggregate over all subtasks) and
	 *       its managed/raw keyed state re-split for {@code newParallelism2}; the new operator6 gets none.
	 *       operator4 is not part of the new topology and the test does not check for its state.</li>
	 * </ul>
	 *
	 * @param scaleType how the second vertex is rescaled: 0 increases its parallelism (10 -> 20),
	 *                  1 decreases it (10 -> 8), 2 (or any other value) keeps it at 10. The first
	 *                  vertex always keeps parallelism 10.
	 */
	public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception {

		/*
		 * Old topology
		 * CHAIN(op1 -> op2) * parallelism1 -> CHAIN(op3 -> op4) * parallelism2
		 */
		Tuple2<JobVertexID, OperatorID> id1 = generateIDPair();
		Tuple2<JobVertexID, OperatorID> id2 = generateIDPair();
		int parallelism1 = 10;
		int maxParallelism1 = 64;

		Tuple2<JobVertexID, OperatorID> id3 = generateIDPair();
		Tuple2<JobVertexID, OperatorID> id4 = generateIDPair();
		int parallelism2 = 10;
		int maxParallelism2 = 64;

		List<KeyGroupRange> keyGroupPartitions2 =
			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);

		Map<OperatorID, OperatorState> operatorStates = new HashMap<>();

		// prepare vertex1 state: legacy + managed + raw operator state, no keyed state
		for (Tuple2<JobVertexID, OperatorID> id : Lists.newArrayList(id1, id2)) {
			OperatorState taskState = new OperatorState(id.f1, parallelism1, maxParallelism1);
			operatorStates.put(id.f1, taskState);
			for (int index = 0; index < taskState.getParallelism(); index++) {
				StreamStateHandle subNonPartitionedState =
					generateStateForVertex(id.f0, index)
						.get(0);
				OperatorStateHandle subManagedOperatorState =
					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
						.get(0);
				OperatorStateHandle subRawOperatorState =
					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
						.get(0);

				OperatorSubtaskState subtaskState = new OperatorSubtaskState(subNonPartitionedState,
					subManagedOperatorState,
					subRawOperatorState,
					null,
					null);
				taskState.putState(index, subtaskState);
			}
		}

		// one list per old operator of vertex2, in the order id3, id4 (index 0 is operator3)
		List<List<ChainedStateHandle<OperatorStateHandle>>> expectedManagedOperatorStates = new ArrayList<>();
		List<List<ChainedStateHandle<OperatorStateHandle>>> expectedRawOperatorStates = new ArrayList<>();

		// prepare vertex2 state: managed + raw operator state; keyed state only for operator3
		for (Tuple2<JobVertexID, OperatorID> id : Lists.newArrayList(id3, id4)) {
			OperatorState operatorState = new OperatorState(id.f1, parallelism2, maxParallelism2);
			operatorStates.put(id.f1, operatorState);
			List<ChainedStateHandle<OperatorStateHandle>> expectedManagedOperatorState = new ArrayList<>();
			List<ChainedStateHandle<OperatorStateHandle>> expectedRawOperatorState = new ArrayList<>();
			expectedManagedOperatorStates.add(expectedManagedOperatorState);
			expectedRawOperatorStates.add(expectedRawOperatorState);

			for (int index = 0; index < operatorState.getParallelism(); index++) {
				OperatorStateHandle subManagedOperatorState =
					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
						.get(0);
				OperatorStateHandle subRawOperatorState =
					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
						.get(0);
				KeyGroupsStateHandle subManagedKeyedState = id.f0.equals(id3.f0)
					? generateKeyGroupState(id.f0, keyGroupPartitions2.get(index), false)
					: null;
				KeyGroupsStateHandle subRawKeyedState = id.f0.equals(id3.f0)
					? generateKeyGroupState(id.f0, keyGroupPartitions2.get(index), true)
					: null;

				expectedManagedOperatorState.add(ChainedStateHandle.wrapSingleHandle(subManagedOperatorState));
				expectedRawOperatorState.add(ChainedStateHandle.wrapSingleHandle(subRawOperatorState));

				OperatorSubtaskState subtaskState = new OperatorSubtaskState(
					null,
					subManagedOperatorState,
					subRawOperatorState,
					subManagedKeyedState,
					subRawKeyedState);
				operatorState.putState(index, subtaskState);
			}
		}

		/*
		 * New topology
		 * CHAIN(op5 -> op1 -> op2) * newParallelism1 -> CHAIN(op3 -> op6) * newParallelism2
		 */
		Tuple2<JobVertexID, OperatorID> id5 = generateIDPair();
		int newParallelism1 = 10;

		Tuple2<JobVertexID, OperatorID> id6 = generateIDPair();
		int newParallelism2 = parallelism2;

		if (scaleType == 0) {
			newParallelism2 = 20;
		} else if (scaleType == 1) {
			newParallelism2 = 8;
		}

		List<KeyGroupRange> newKeyGroupPartitions2 =
			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);

		// operator IDs are listed in reverse chain order (last operator first), so for vertex1
		// chain index 0 = operator2, 1 = operator1, 2 = operator5; for vertex2 index 0 = operator6, 1 = operator3
		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
			id5.f0,
			Lists.newArrayList(id2.f1, id1.f1, id5.f1),
			newParallelism1,
			maxParallelism1);

		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
			id3.f0,
			Lists.newArrayList(id6.f1, id3.f1),
			newParallelism2,
			maxParallelism2);

		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();

		tasks.put(id5.f0, newJobVertex1);
		tasks.put(id3.f0, newJobVertex2);

		JobID jobID = new JobID();
		// the store is spied so that getLatestCheckpoint() can be stubbed to return the checkpoint built above
		StandaloneCompletedCheckpointStore standaloneCompletedCheckpointStore =
			spy(new StandaloneCompletedCheckpointStore(1));

		CompletedCheckpoint completedCheckpoint = new CompletedCheckpoint(
				jobID,
				2,
				System.currentTimeMillis(),
				System.currentTimeMillis() + 3000,
				operatorStates,
				Collections.<MasterState>emptyList(),
				CheckpointProperties.forStandardCheckpoint(),
				null,
				null);

		when(standaloneCompletedCheckpointStore.getLatestCheckpoint()).thenReturn(completedCheckpoint);

		// set up the coordinator and validate the initial state
		// NOTE(review): this passes a fresh JobID instead of `jobID` (used for the completed checkpoint above);
		// confirm whether that mismatch is intentional.
		CheckpointCoordinator coord = new CheckpointCoordinator(
			new JobID(),
			600000,
			600000,
			0,
			Integer.MAX_VALUE,
			ExternalizedCheckpointSettings.none(),
			newJobVertex1.getTaskVertices(),
			newJobVertex1.getTaskVertices(),
			newJobVertex1.getTaskVertices(),
			new StandaloneCheckpointIDCounter(),
			standaloneCompletedCheckpointStore,
			null,
			Executors.directExecutor());

		coord.restoreLatestCheckpointedState(tasks, false, true);

		// vertex1: each subtask must see exactly the state its old subtask of operator1/operator2 had
		for (int i = 0; i < newJobVertex1.getParallelism(); i++) {

			TaskStateHandles taskStateHandles = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
			ChainedStateHandle<StreamStateHandle> actualSubNonPartitionedState = taskStateHandles.getLegacyOperatorState();
			List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState();
			List<Collection<OperatorStateHandle>> actualSubRawOperatorState = taskStateHandles.getRawOperatorState();

			assertNull(taskStateHandles.getManagedKeyedState());
			assertNull(taskStateHandles.getRawKeyedState());

			// operator5 (new operator, no state)
			{
				int operatorIndexInChain = 2;
				assertNull(actualSubNonPartitionedState.get(operatorIndexInChain));
				assertNull(actualSubManagedOperatorState.get(operatorIndexInChain));
				assertNull(actualSubRawOperatorState.get(operatorIndexInChain));
			}
			// operator1
			{
				int operatorIndexInChain = 1;
				ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id1.f0, i);
				ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle(
					id1.f0, i, 2, 8, false);
				ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle(
					id1.f0, i, 2, 8, true);

				assertTrue(CommonTestUtils.isSteamContentEqual(
					expectSubNonPartitionedState.get(0).openInputStream(),
					actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));

				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
					actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));

				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
					actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
			}
			// operator2
			{
				int operatorIndexInChain = 0;
				ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id2.f0, i);
				ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle(
					id2.f0, i, 2, 8, false);
				ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle(
					id2.f0, i, 2, 8, true);

				assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(),
					actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));

				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
					actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));

				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
					actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
			}
		}

		// vertex2: operator state may be redistributed, so it is collected here and compared in aggregate below
		List<List<Collection<OperatorStateHandle>>> actualManagedOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
		List<List<Collection<OperatorStateHandle>>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());

		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {

			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();

			// operator 3
			{
				int operatorIndexInChain = 1;
				List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = new ArrayList<>(1);
				actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));

				List<Collection<OperatorStateHandle>> actualSubRawOperatorState = new ArrayList<>(1);
				actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));

				actualManagedOperatorStates.add(actualSubManagedOperatorState);
				actualRawOperatorStates.add(actualSubRawOperatorState);

				assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
			}

			// operator 6 (new operator, no state)
			{
				int operatorIndexInChain = 0;
				assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
				assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
				assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));

			}

			// keyed state of operator3, expected according to the new key-group split
			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false);
			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true);

			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();

			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
		}

		// index 0 of the expected lists holds operator3's state (see the vertex2 preparation loop)
		comparePartitionableState(expectedManagedOperatorStates.get(0), actualManagedOperatorStates);
		comparePartitionableState(expectedRawOperatorStates.get(0), actualRawOperatorStates);
	}

	/**
	 * Tests that the externalized checkpoint configuration is respected.
*/
	@Test
	public void testExternalizedCheckpoints() throws Exception {
		final JobID jid = new JobID();
		final long timestamp = System.currentTimeMillis();

		// create some mock Execution vertices that receive the checkpoint trigger messages
		final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);

		// set up the coordinator with externalized checkpoints enabled
		CheckpointCoordinator coord = new CheckpointCoordinator(
			jid,
			600000,
			600000,
			0,
			Integer.MAX_VALUE,
			ExternalizedCheckpointSettings.externalizeCheckpoints(true),
			new ExecutionVertex[] { vertex1 },
			new ExecutionVertex[] { vertex1 },
			new ExecutionVertex[] { vertex1 },
			new StandaloneCheckpointIDCounter(),
			new StandaloneCompletedCheckpointStore(1),
			"fake-directory",
			Executors.directExecutor());

		assertTrue(coord.triggerCheckpoint(timestamp, false));

		// without this guard the loop below would pass vacuously if no checkpoint were pending
		assertFalse(coord.getPendingCheckpoints().isEmpty());

		CheckpointProperties expected = CheckpointProperties.forExternalizedCheckpoint(true);
		for (PendingCheckpoint checkpoint : coord.getPendingCheckpoints().values()) {
			assertEquals(expected, checkpoint.getProps());
		}

		coord.shutdown(JobStatus.FINISHED);
	}

	/**
	 * Repartitions one handle holding two BROADCAST states ("t-1", "t-2") and one SPLIT_DISTRIBUTE state
	 * ("t-3"), each with two offsets, to parallelism 3. Each broadcast state must appear in all 3 resulting
	 * handles with both offsets; the split state must appear in 2 handles with one offset each.
	 */
	@Test
	public void testReplicateModeStateHandle() {
		Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = new HashMap<>(3);
		metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.BROADCAST));
		metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.BROADCAST));
		metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
		OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100]));

		OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
		List<Collection<OperatorStateHandle>> repartitionedStates =
			repartitioner.repartitionState(Collections.singletonList(osh), 3);

		// state name -> number of resulting handles that contain it
		Map<String, Integer> checkCounts = new HashMap<>(3);

		for (Collection<OperatorStateHandle> operatorStateHandles : repartitionedStates) {
			for (OperatorStateHandle operatorStateHandle : operatorStateHandles) {
				for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> stateNameToMetaInfo :
					operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {

					String stateName = stateNameToMetaInfo.getKey();
					Integer count = checkCounts.get(stateName);
					checkCounts.put(stateName, count == null ? 1 : count + 1);

					OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue();
					if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) {
						Assert.assertEquals(1, stateMetaInfo.getOffsets().length);
					} else {
						Assert.assertEquals(2, stateMetaInfo.getOffsets().length);
					}
				}
			}
		}

		Assert.assertEquals(3, checkCounts.size());
		Assert.assertEquals(3, checkCounts.get("t-1").intValue());
		Assert.assertEquals(3, checkCounts.get("t-2").intValue());
		Assert.assertEquals(2, checkCounts.get("t-3").intValue());
	}

	// ------------------------------------------------------------------------
	//  Utilities
	// ------------------------------------------------------------------------

	/**
	 * Generates a keyed state handle holding one pseudo-random {@code Integer} per key group of
	 * {@code keyGroupPartition}. Each value comes from a {@link Random} seeded with the vertex ID's hash
	 * code and the key-group index (using a different formula for raw state), so calls with the same
	 * arguments yield the same values.
	 *
	 * @param jobVertexID vertex whose hash code seeds the values
	 * @param keyGroupPartition key groups to generate state for
	 * @param rawState selects the raw-state seed formula instead of the managed-state one
	 * @return the generated handle
	 * @throws IOException if serializing the values fails
	 */
	public static KeyGroupsStateHandle generateKeyGroupState(
			JobVertexID jobVertexID,
			KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {

		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());

		int vertexHash = jobVertexID.hashCode();

		// generate one simulated state value per key group
		for (int keyGroupIndex : keyGroupPartition) {
			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
			Random random = new Random(seed);
			int simulatedStateValue = random.nextInt();
			testStatesLists.add(simulatedStateValue);
		}

		return generateKeyGroupState(keyGroupPartition, testStatesLists);
	}

	/**
	 * Serializes {@code states} into one in-memory handle and records the offset of each key group.
	 * The i-th state (in list order) belongs to the i-th key group of {@code keyGroupRange}.
	 *
	 * @param keyGroupRange key groups covered by the handle
	 * @param states one value per key group
	 * @return a handle whose key-group offsets point at the serialized values
	 * @throws IOException if serializing a value fails
	 * @throws IllegalArgumentException if the number of states differs from the number of key groups
	 */
	public static KeyGroupsStateHandle generateKeyGroupState(
			KeyGroupRange keyGroupRange,
			List<? extends Serializable> states) throws IOException {

		Preconditions.checkArgument(
			keyGroupRange.getNumberOfKeyGroups() == states.size(),
			"Expected %s states (one per key group), but got %s.",
			keyGroupRange.getNumberOfKeyGroups(),
			states.size());

		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
			serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));

		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));

		ByteStreamStateHandle allSerializedStatesHandle = new TestByteStreamStateHandleDeepCompare(
			String.valueOf(UUID.randomUUID()),
			serializedDataWithOffsets.f0);

		return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
	}

	/**
	 * Java-serializes every element of every inner list back to back into a single byte array.
	 *
	 * @param serializables groups of values to serialize
	 * @return {@code f0}: the concatenated bytes; {@code f1}: for each inner list (same order), the start
	 *         offset in {@code f0} of each of its elements
	 * @throws IOException if serializing a value fails
	 */
	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
			List<List<? extends Serializable>> serializables) throws IOException {

		List<long[]> offsets = new ArrayList<>(serializables.size());
		List<byte[]> serializedGroupValues = new ArrayList<>();

		int runningGroupsOffset = 0;
		for (List<? extends Serializable> list : serializables) {

			long[] currentOffsets = new long[list.size()];
			offsets.add(currentOffsets);

			for (int i = 0; i < list.size(); ++i) {
				currentOffsets[i] = runningGroupsOffset;
				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
				serializedGroupValues.add(serializedValue);
				runningGroupsOffset += serializedValue.length;
			}
		}

		// write all generated values into a single byte array, indexed by the offsets tracked above
		byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
		runningGroupsOffset = 0;
		for (byte[] serializedGroupValue : serializedGroupValues) {
			System.arraycopy(
				serializedGroupValue,
				0,
				allSerializedValuesConcatenated,
				runningGroupsOffset,
				serializedGroupValue.length);
			runningGroupsOffset += serializedGroupValue.length;
		}
		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
	}

	/**
	 * Generates a single-element chained stream state handle holding one pseudo-random int, seeded
	 * with {@code jobVertexID.hashCode() + index}.
	 *
	 * @throws IOException if serializing the value fails
	 */
	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
			JobVertexID jobVertexID,
			int index) throws IOException {

		Random random = new Random(jobVertexID.hashCode() + index);
		int value = random.nextInt();
		return generateChainedStateHandle(value);
	}

	/**
	 * Wraps the Java-serialized {@code value} into a single-element chained state handle.
	 *
	 * @throws IOException if serializing the value fails
	 */
	public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
			Serializable value) throws IOException {
		return ChainedStateHandle.wrapSingleHandle(
				TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value));
	}

	/**
	 * Generates a single-element chained operator state handle with {@code namedStates} states named
	 * {@code "state-0"}, {@code "state-1"}, ..., each holding {@code partitionsPerState} pseudo-random
	 * {@code Integer} partitions in SPLIT_DISTRIBUTE mode.
	 *
	 * <p>NOTE: the seed is {@code hashCode * index + i * namedStates}, so for {@code index == 0} the
	 * generated values do not depend on {@code jobVertexID}.
	 *
	 * @param jobVertexID vertex whose hash code seeds the values
	 * @param index subtask index mixed into the seed
	 * @param namedStates number of named states
	 * @param partitionsPerState number of partitions per named state
	 * @param rawState derives a different seed for raw state
	 * @throws IOException if serializing a value fails
	 */
	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
			JobVertexID jobVertexID,
			int index,
			int namedStates,
			int partitionsPerState,
			boolean rawState) throws IOException {

		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);

		for (int i = 0; i < namedStates; ++i) {
			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
			// generate state
			int seed = jobVertexID.hashCode() * index + i * namedStates;
			if (rawState) {
				seed = (seed + 1) * 31;
			}
			Random random = new Random(seed);
			for (int j = 0; j < partitionsPerState; ++j) {
				int simulatedStateValue = random.nextInt();
				testStatesLists.add(simulatedStateValue);
			}
			statesListsMap.put("state-" + i, testStatesLists);
		}

		return generateChainedPartitionableStateHandle(statesListsMap);
	}

	/**
	 * Serializes all partitions of all named states into one in-memory handle and registers each state
	 * in SPLIT_DISTRIBUTE mode with the offsets of its partitions.
	 *
	 * @param states state name -> partition values of that state
	 * @throws IOException if serializing a value fails
	 */
	private static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
		Map<String, List<? extends Serializable>> states) throws IOException {

		// collect names and values in one pass so that name i is guaranteed to pair with offsets i
		List<String> stateNames = new ArrayList<>(states.size());
		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
			stateNames.add(entry.getKey());
			namedStateSerializables.add(entry.getValue());
		}

		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);

		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());

		for (int idx = 0; idx < stateNames.size(); ++idx) {
			offsetsMap.put(
				stateNames.get(idx),
				new OperatorStateHandle.StateMetaInfo(
					serializationWithOffsets.f1.get(idx),
					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
		}

		ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare(
			String.valueOf(UUID.randomUUID()),
			serializationWithOffsets.f0);

		OperatorStateHandle operatorStateHandle = new OperatorStateHandle(offsetsMap, streamStateHandle);
		return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
	}

	/**
	 * Mocks an {@link ExecutionJobVertex} with a single operator whose ID is derived from
	 * {@code jobVertexID}.
	 */
	static ExecutionJobVertex mockExecutionJobVertex(
		JobVertexID jobVertexID,
		int parallelism,
		int maxParallelism) {

		return mockExecutionJobVertex(
			jobVertexID,
			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
			parallelism,
			maxParallelism
		);
	}

	/**
	 * Mocks an {@link ExecutionJobVertex} with {@code parallelism} RUNNING subtasks that report the given
	 * operator IDs. Each subtask reports its own index as parallel subtask index. All user-defined operator
	 * IDs are {@code null}.
	 *
	 * @param jobVertexID ID reported by the vertex and its subtasks
	 * @param operatorIDs operator IDs of the vertex, in the order {@code getOperatorIDs()} should return them
	 * @param parallelism number of subtasks
	 * @param maxParallelism max parallelism reported by the vertex and its subtasks
	 */
	static ExecutionJobVertex mockExecutionJobVertex(
		JobVertexID jobVertexID,
		List<OperatorID> operatorIDs,
		int parallelism,
		int maxParallelism) {
		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);

		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];

		for (int i = 0; i < parallelism; i++) {
			executionVertices[i] = mockExecutionVertex(
				new ExecutionAttemptID(),
				jobVertexID,
				operatorIDs,
				parallelism,
				maxParallelism,
				ExecutionState.RUNNING);

			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
		}

		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
		when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
		when(executionJobVertex.getOperatorIDs()).thenReturn(operatorIDs);
		when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[operatorIDs.size()]));

		return executionJobVertex;
	}

	/**
	 * Mocks a RUNNING execution vertex of a fresh job vertex (one operator, parallelism and max
	 * parallelism 1) whose current attempt has the given ID.
	 */
	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
		JobVertexID jobVertexID = new JobVertexID();
		return mockExecutionVertex(
			attemptID,
			jobVertexID,
			Arrays.asList(OperatorID.fromJobVertexID(jobVertexID)),
			1,
			1,
			ExecutionState.RUNNING);
	}

	/**
	 * Mocks an execution vertex whose current attempt is a spied {@link Execution}. The attempt reports
	 * {@code attemptID} and returns {@code state} from {@code getState()}, followed by
	 * {@code successiveStates} on later calls.
	 */
	private static ExecutionVertex mockExecutionVertex(
		ExecutionAttemptID attemptID,
		JobVertexID jobVertexID,
		List<OperatorID> operatorIDs,
		int parallelism,
		int maxParallelism,
		ExecutionState state,
		ExecutionState ... successiveStates) {

		ExecutionVertex vertex = mock(ExecutionVertex.class);

		final Execution exec = spy(new Execution(
			mock(Executor.class),
			vertex,
			1,
			1L,
			1L,
			Time.milliseconds(500L)
		));
		when(exec.getAttemptId()).thenReturn(attemptID);
		when(exec.getState()).thenReturn(state, successiveStates);

		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);

		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
		when(jobVertex.getOperatorIDs()).thenReturn(operatorIDs);

		when(vertex.getJobVertex()).thenReturn(jobVertex);

		return vertex;
	}

	/**
	 * Mocks a serializable {@link SubtaskState} with generated legacy, managed operator and managed keyed
	 * state for the given vertex and subtask. Raw operator and raw keyed state are {@code null}.
	 *
	 * @throws IOException if generating the state fails
	 */
	static SubtaskState mockSubtaskState(
		JobVertexID jobVertexID,
		int index,
		KeyGroupRange keyGroupRange) throws IOException {

		ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID, index);
		ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false);
		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);

		SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable());

		doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState();
		doReturn(partitionableState).when(subtaskState).getManagedOperatorState();
		doReturn(null).when(subtaskState).getRawOperatorState();
		doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState();
		doReturn(null).when(subtaskState).getRawKeyedState();

		return subtaskState;
	}

	/**
	 * Asserts that every subtask of {@code executionJobVertex} was restored with the state the generate*
	 * helpers produce for {@code jobVertexID}. Legacy and managed operator state are compared for chain
	 * position 0; managed keyed state is compared against {@code keyGroupPartitions.get(subtaskIndex)}.
	 */
	public static void verifyStateRestore(
			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
			List<KeyGroupRange> keyGroupPartitions) throws Exception {

		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {

			TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();

			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = taskStateHandles.getLegacyOperatorState();
			assertTrue(CommonTestUtils.isSteamContentEqual(
					expectNonPartitionedState.get(0).openInputStream(),
					actualNonPartitionedState.get(0).openInputStream()));

			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);

			List<Collection<OperatorStateHandle>> actualPartitionableState = taskStateHandles.getManagedOperatorState();

			assertTrue(CommonTestUtils.isSteamContentEqual(
					expectedOpStateBackend.get(0).openInputStream(),
					actualPartitionableState.get(0).iterator().next().openInputStream()));

			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
					jobVertexID, keyGroupPartitions.get(i), false);
			Collection<KeyedStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
		}
	}

	/**
	 * Asserts that {@code actualPartitionedKeyGroupState} consists of {@link KeyGroupsStateHandle}s that
	 * together cover as many key groups as the first expected handle. It also asserts that, for every key
	 * group of that handle, each actual handle containing the key group stores the same value. Only the
	 * first expected handle is used.
	 */
	public static void compareKeyedState(
		Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
		Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {

		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
		int actualTotalKeyGroups = 0;
		for (KeyedStateHandle keyedStateHandle : actualPartitionedKeyGroupState) {
			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);

			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
		}

		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);

		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
				inputStream.seek(offset);
				int expectedKeyGroupState =
					InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {

					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);

					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
							actualInputStream.seek(actualOffset);
							int actualGroupState = InstantiationUtil.
								deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
							assertEquals(expectedKeyGroupState, actualGroupState);
						}
					}
				}
			}
		}
	}

	/**
	 * Asserts that {@code expected} and {@code actual} contain the same multiset of
	 * "chain index : state name : value" entries, regardless of how the partitions are distributed over
	 * subtasks. {@code null} subtask lists in {@code actual} are skipped; {@code null} per-operator
	 * collections inside them fail the assertion.
	 */
	public static void comparePartitionableState(
		List<ChainedStateHandle<OperatorStateHandle>> expected,
		List<List<Collection<OperatorStateHandle>>> actual) throws Exception {

		List<String> expectedResult = new ArrayList<>();
		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
				collectResult(i, operatorStateHandle, expectedResult);
			}
		}
		Collections.sort(expectedResult);

		List<String> actualResult = new ArrayList<>();
		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
			if (collectionList != null) {
				for (int i = 0; i < collectionList.size(); ++i) {
					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
					Assert.assertNotNull(stateHandles);
					for (OperatorStateHandle operatorStateHandle : stateHandles) {
						collectResult(i, operatorStateHandle, actualResult);
					}
				}
			}
		}

		Collections.sort(actualResult);
		Assert.assertEquals(expectedResult, actualResult);
	}

	/**
	 * Deserializes every partition of every named state in {@code operatorStateHandle} and appends
	 * {@code "opIdx : stateName : value"} to {@code resultCollector}.
	 */
	private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
				for (long offset : entry.getValue().getOffsets()) {
					in.seek(offset);
					Integer state = InstantiationUtil.
						deserializeObject(in, Thread.currentThread().getContextClassLoader());
					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
				}
			}
		}
	}

	/**
	 * Checks {@code createKeyGroupPartitions} against {@code computeOperatorIndexForKeyGroup} for a few
	 * fixed and 1000 seeded-random (maxParallelism, parallelism) combinations.
	 */
	@Test
	public void testCreateKeyGroupPartitions() {
		testCreateKeyGroupPartitions(1, 1);
		testCreateKeyGroupPartitions(13, 1);
		testCreateKeyGroupPartitions(13, 2);
		testCreateKeyGroupPartitions(Short.MAX_VALUE, 1);
		testCreateKeyGroupPartitions(Short.MAX_VALUE, 13);
		testCreateKeyGroupPartitions(Short.MAX_VALUE, Short.MAX_VALUE);

		Random r = new Random(1234);
		for (int k = 0; k < 1000; ++k) {
			int maxParallelism = 1 + r.nextInt(Short.MAX_VALUE - 1);
			int parallelism = 1 + r.nextInt(maxParallelism);
			testCreateKeyGroupPartitions(maxParallelism, parallelism);
		}
	}

	/**
	 * Verifies that a periodic trigger is declined with {@code PERIODIC_SCHEDULER_SHUTDOWN} (the periodic
	 * scheduler is never started in this test), while a non-periodic trigger succeeds.
	 */
	@Test
	public void testStopPeriodicScheduler() throws Exception {
		// create some mock Execution vertices that receive the checkpoint trigger messages
		final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);

		// set up the coordinator and validate the initial state
		CheckpointCoordinator coord = new CheckpointCoordinator(
			new JobID(),
			600000,
			600000,
			0,
			Integer.MAX_VALUE,
			ExternalizedCheckpointSettings.none(),
			new ExecutionVertex[] { vertex1 },
			new ExecutionVertex[] { vertex1 },
			new ExecutionVertex[] { vertex1 },
			new StandaloneCheckpointIDCounter(),
			new StandaloneCompletedCheckpointStore(1),
			null,
			Executors.directExecutor());

		// Periodic
		CheckpointTriggerResult triggerResult = coord.triggerCheckpoint(
				System.currentTimeMillis(),
				CheckpointProperties.forStandardCheckpoint(),
				null,
				true);

		assertTrue(triggerResult.isFailure());
		assertEquals(CheckpointDeclineReason.PERIODIC_SCHEDULER_SHUTDOWN, triggerResult.getFailureReason());

		// Not periodic
		triggerResult = coord.triggerCheckpoint(
				System.currentTimeMillis(),
				CheckpointProperties.forStandardCheckpoint(),
				null,
				false);

		assertFalse(triggerResult.isFailure());
	}

	/**
	 * Asserts that every key group in {@code [0, maxParallelism)} lies in the range created for the
	 * operator index that {@link KeyGroupRangeAssignment} assigns it to.
	 */
	private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) {
		List<KeyGroupRange> ranges = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism, parallelism);
		for (int i = 0; i < maxParallelism; ++i) {
			KeyGroupRange range = ranges.get(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i));
			if (!range.contains(i)) {
				Assert.fail("Could not find expected key-group " + i + " in range " + range);
			}
		}
	}

	/**
	 * Seeded randomized test of {@link RoundRobinOperatorStateRepartitioner} over 10000 combinations of
	 * old/new parallelism, number of named states and partitions per state (each in 1..9).
	 */
	@Test
	public void testPartitionableStateRepartitioning() {
		Random r = new Random(42);

		for (int run = 0; run < 10000; ++run) {
			int oldParallelism = 1 + r.nextInt(9);
			int newParallelism = 1 + r.nextInt(9);

			int numNamedStates = 1 + r.nextInt(9);
			int maxPartitionsPerState = 1 + r.nextInt(9);

			doTestPartitionableStateRepartitioning(
				r, oldParallelism, newParallelism, numNamedStates, maxPartitionsPerState);
		}
	}

	/**
	 * Builds {@code oldParallelism} operator state handles with random named states (about 1 in 10 in
	 * BROADCAST mode, the rest SPLIT_DISTRIBUTE) and repartitions them to {@code newParallelism}. It then
	 * checks three things:
	 * <ul>
	 *   <li>the per-subtask partition counts differ by at most 1;</li>
	 *   <li>the total partition count matches, with broadcast partitions counted {@code newParallelism} times;</li>
	 *   <li>every offset of every state ends up somewhere with exactly that multiplicity.</li>
	 * </ul>
	 */
	private void doTestPartitionableStateRepartitioning(
		Random r, int oldParallelism, int newParallelism, int numNamedStates, int maxPartitionsPerState) {

		List<OperatorStateHandle> previousParallelOpInstanceStates = new ArrayList<>(oldParallelism);

		for (int i = 0; i < oldParallelism; ++i) {
			Path fakePath = new Path("/fake-" + i);
			Map<String, OperatorStateHandle.StateMetaInfo> namedStatesToOffsets = new HashMap<>();
			int off = 0;
			for (int s = 0; s < numNamedStates; ++s) {
				long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)];

				for (int o = 0; o < offs.length; ++o) {
					offs[o] = off;
					++off;
				}

				OperatorStateHandle.Mode mode = r.nextInt(10) == 0 ?
					OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE;
				namedStatesToOffsets.put(
					"State-" + s,
					new OperatorStateHandle.StateMetaInfo(offs, mode));

			}

			previousParallelOpInstanceStates.add(
				new OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1)));
		}

		// expected: delegate handle -> state name -> all offsets, each repeated `replication` times
		Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>();

		int expectedTotalPartitions = 0;
		for (OperatorStateHandle psh : previousParallelOpInstanceStates) {
			Map<String, OperatorStateHandle.StateMetaInfo> offsMap = psh.getStateNameToPartitionOffsets();
			Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size());
			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : offsMap.entrySet()) {

				long[] offs = e.getValue().getOffsets();
				int replication = e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ?
					newParallelism : 1;

				expectedTotalPartitions += replication * offs.length;
				List<Long> offsList = new ArrayList<>(offs.length);

				for (int i = 0; i < offs.length; ++i) {
					for (int p = 0; p < replication; ++p) {
						offsList.add(offs[i]);
					}
				}
				offsMapWithList.put(e.getKey(), offsList);
			}
			expected.put(psh.getDelegateStateHandle(), offsMapWithList);
		}

		OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;

		List<Collection<OperatorStateHandle>> pshs =
			repartitioner.repartitionState(previousParallelOpInstanceStates, newParallelism);

		Map<StreamStateHandle, Map<String, List<Long>>> actual = new HashMap<>();

		int minCount = Integer.MAX_VALUE;
		int maxCount = 0;
		int actualTotalPartitions = 0;
		for (int p = 0; p < newParallelism; ++p) {
			int partitionCount = 0;

			Collection<OperatorStateHandle> pshc = pshs.get(p);
			for (OperatorStateHandle sh : pshc) {
				for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {

					Map<String, List<Long>> stateToOffsets = actual.get(sh.getDelegateStateHandle());
					if (stateToOffsets == null) {
						stateToOffsets = new HashMap<>();
						actual.put(sh.getDelegateStateHandle(), stateToOffsets);
					}

					List<Long> actualOffs = stateToOffsets.get(namedState.getKey());
					if (actualOffs == null) {
						actualOffs = new ArrayList<>();
						stateToOffsets.put(namedState.getKey(), actualOffs);
					}
					long[] add = namedState.getValue().getOffsets();
					for (int i = 0; i < add.length; ++i) {
						actualOffs.add(add[i]);
					}

					partitionCount += namedState.getValue().getOffsets().length;
				}
			}

			minCount = Math.min(minCount, partitionCount);
			maxCount = Math.max(maxCount, partitionCount);
			actualTotalPartitions += partitionCount;
		}

		// expected lists are generated in ascending order; sort actual lists so the map comparison is order-insensitive
		for (Map<String, List<Long>> v : actual.values()) {
			for (List<Long> l : v.values()) {
				Collections.sort(l);
			}
		}

		int maxLoadDiff = maxCount - minCount;
		Assert.assertTrue("Difference in partition load is > 1 : " + maxLoadDiff, maxLoadDiff <= 1);
		Assert.assertEquals(expectedTotalPartitions, actualTotalPartitions);
		Assert.assertEquals(expected, actual);
	}

	/**
	 * Tests that the pending checkpoint stats callbacks are created.
	 */
	@Test
	public void testCheckpointStatsTrackerPendingCheckpointCallback() {
		final long timestamp = System.currentTimeMillis();
		ExecutionVertex vertex1 = mockExecutionVertex(new ExecutionAttemptID());

		// set up the coordinator and validate the initial state
		CheckpointCoordinator coord = new CheckpointCoordinator(
			new JobID(),
			600000,
			600000,
			0,
			Integer.MAX_VALUE,
			ExternalizedCheckpointSettings.none(),
			new ExecutionVertex[]{vertex1},
			new ExecutionVertex[]{vertex1},
			new ExecutionVertex[]{vertex1},
			new StandaloneCheckpointIDCounter(),
			new StandaloneCompletedCheckpointStore(1),
			null,
			Executors.directExecutor());

		CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class);
		coord.setCheckpointStatsTracker(tracker);

		when(tracker.reportPendingCheckpoint(anyLong(), anyLong(), any(CheckpointProperties.class)))
			.thenReturn(mock(PendingCheckpointStats.class));

		// Trigger a checkpoint and verify callback
		assertTrue(coord.triggerCheckpoint(timestamp, false));

		verify(tracker, times(1))
			.reportPendingCheckpoint(eq(1L), eq(timestamp), eq(CheckpointProperties.forStandardCheckpoint()));
	}

	/**
	 * Tests that the restore callbacks are called if registered.
	 */
	@Test
	public void testCheckpointStatsTrackerRestoreCallback() throws Exception {
		ExecutionVertex vertex1 = mockExecutionVertex(new ExecutionAttemptID());

		StandaloneCompletedCheckpointStore store = new StandaloneCompletedCheckpointStore(1);

		// set up the coordinator and validate the initial state
		CheckpointCoordinator coord = new CheckpointCoordinator(
			new JobID(),
			600000,
			600000,
			0,
			Integer.MAX_VALUE,
			ExternalizedCheckpointSettings.none(),
			new ExecutionVertex[]{vertex1},
			new ExecutionVertex[]{vertex1},
			new ExecutionVertex[]{vertex1},
			new StandaloneCheckpointIDCounter(),
			store,
			null,
			Executors.directExecutor());

		store.addCheckpoint(new CompletedCheckpoint(
			new JobID(),
			0,
			0,
			0,
			Collections.<OperatorID, OperatorState>emptyMap(),
			Collections.<MasterState>emptyList(),
			CheckpointProperties.forStandardCheckpoint(),
			null,
			null));

		CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class);
		coord.setCheckpointStatsTracker(tracker);

		assertTrue(coord.restoreLatestCheckpointedState(Collections.<JobVertexID, ExecutionJobVertex>emptyMap(), false, true));

		verify(tracker, times(1))
			.reportRestoredCheckpoint(any(RestoredCheckpointStats.class));
	}

	/**
	 * {@link OperatorState} that stores a Mockito spy of every subtask state put into it, presumably so
	 * tests can verify interactions with those states.
	 */
	private static final class SpyInjectingOperatorState extends OperatorState {

		private static final long serialVersionUID = -4004437428483663815L;

		public SpyInjectingOperatorState(OperatorID operatorID, int parallelism, int maxParallelism) {
			super(operatorID, parallelism, maxParallelism);
		}

		@Override
		public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) {
			super.putState(subtaskIndex, spy(subtaskState));
		}
	}
}