/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.checkpoint;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.concurrent.Executors;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskStateHandles;
import org.apache.flink.runtime.util.SerializableObject;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.junit.Test;
import org.mockito.Mockito;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Tests concerning the restoring of state from a checkpoint to the task executions.
*/
public class CheckpointStateRestoreTest {
/**
* Tests that on restore the task state is reset for each stateful task.
*/
@Test
public void testSetState() {
try {
final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID();
final JobVertexID statelessId = new JobVertexID();
Execution statefulExec1 = mockExecution();
Execution statefulExec2 = mockExecution();
Execution statefulExec3 = mockExecution();
Execution statelessExec1 = mockExecution();
Execution statelessExec2 = mockExecution();
ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 3);
ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1, 3);
ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2, 3);
ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 2);
ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1, 2);
ExecutionJobVertex stateful = mockExecutionJobVertex(statefulId,
new ExecutionVertex[] { stateful1, stateful2, stateful3 });
ExecutionJobVertex stateless = mockExecutionJobVertex(statelessId,
new ExecutionVertex[] { stateless1, stateless2 });
Map<JobVertexID, ExecutionJobVertex> map = new HashMap<JobVertexID, ExecutionJobVertex>();
map.put(statefulId, stateful);
map.put(statelessId, stateless);
CheckpointCoordinator coord = new CheckpointCoordinator(
jid,
200000L,
200000L,
0,
Integer.MAX_VALUE,
ExternalizedCheckpointSettings.none(),
new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
new ExecutionVertex[0],
new StandaloneCheckpointIDCounter(),
new StandaloneCompletedCheckpointStore(1),
null,
Executors.directExecutor());
// create ourselves a checkpoint with state
final long timestamp = 34623786L;
coord.triggerCheckpoint(timestamp, false);
PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
final long checkpointId = pending.getCheckpointId();
SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null);
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
assertEquals(0, coord.getNumberOfPendingCheckpoints());
// let the coordinator inject the state
coord.restoreLatestCheckpointedState(map, true, false);
// verify that each stateful vertex got the state
final TaskStateHandles taskStateHandles = new TaskStateHandles(
serializedState,
Collections.<Collection<OperatorStateHandle>>singletonList(null),
Collections.<Collection<OperatorStateHandle>>singletonList(null),
Collections.singletonList(serializedKeyGroupStates),
null);
BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() {
@Override
public boolean matches(Object o) {
if (o instanceof TaskStateHandles) {
return o.equals(taskStateHandles);
}
return false;
}
@Override
public void describeTo(Description description) {
description.appendValue(taskStateHandles);
}
};
verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher));
verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher));
verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher));
verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testNoCheckpointAvailable() {
try {
CheckpointCoordinator coord = new CheckpointCoordinator(
new JobID(),
200000L,
200000L,
0,
Integer.MAX_VALUE,
ExternalizedCheckpointSettings.none(),
new ExecutionVertex[] { mock(ExecutionVertex.class) },
new ExecutionVertex[] { mock(ExecutionVertex.class) },
new ExecutionVertex[0],
new StandaloneCheckpointIDCounter(),
new StandaloneCompletedCheckpointStore(1),
null,
Executors.directExecutor());
try {
coord.restoreLatestCheckpointedState(new HashMap<JobVertexID, ExecutionJobVertex>(), true, false);
fail("this should throw an exception");
}
catch (IllegalStateException e) {
// expected
}
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
/**
* Tests that the allow non restored state flag is correctly handled.
*
* The flag only applies for state that is part of the checkpoint.
*/
@Test
public void testNonRestoredState() throws Exception {
// --- (1) Create tasks to restore checkpoint with ---
JobVertexID jobVertexId1 = new JobVertexID();
JobVertexID jobVertexId2 = new JobVertexID();
OperatorID operatorId1 = OperatorID.fromJobVertexID(jobVertexId1);
// 1st JobVertex
ExecutionVertex vertex11 = mockExecutionVertex(mockExecution(), jobVertexId1, 0, 3);
ExecutionVertex vertex12 = mockExecutionVertex(mockExecution(), jobVertexId1, 1, 3);
ExecutionVertex vertex13 = mockExecutionVertex(mockExecution(), jobVertexId1, 2, 3);
// 2nd JobVertex
ExecutionVertex vertex21 = mockExecutionVertex(mockExecution(), jobVertexId2, 0, 2);
ExecutionVertex vertex22 = mockExecutionVertex(mockExecution(), jobVertexId2, 1, 2);
ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(jobVertexId1, new ExecutionVertex[] { vertex11, vertex12, vertex13 });
ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(jobVertexId2, new ExecutionVertex[] { vertex21, vertex22 });
Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
tasks.put(jobVertexId1, jobVertex1);
tasks.put(jobVertexId2, jobVertex2);
CheckpointCoordinator coord = new CheckpointCoordinator(
new JobID(),
Integer.MAX_VALUE,
Integer.MAX_VALUE,
0,
Integer.MAX_VALUE,
ExternalizedCheckpointSettings.none(),
new ExecutionVertex[] {},
new ExecutionVertex[] {},
new ExecutionVertex[] {},
new StandaloneCheckpointIDCounter(),
new StandaloneCompletedCheckpointStore(1),
null,
Executors.directExecutor());
StreamStateHandle serializedState = CheckpointCoordinatorTest
.generateChainedStateHandle(new SerializableObject())
.get(0);
// --- (2) Checkpoint misses state for a jobVertex (should work) ---
Map<OperatorID, OperatorState> checkpointTaskStates = new HashMap<>();
{
OperatorState taskState = new OperatorState(operatorId1, 3, 3);
taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null));
taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null));
taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null));
checkpointTaskStates.put(operatorId1, taskState);
}
CompletedCheckpoint checkpoint = new CompletedCheckpoint(
new JobID(),
0,
1,
2,
new HashMap<>(checkpointTaskStates),
Collections.<MasterState>emptyList(),
CheckpointProperties.forStandardCheckpoint(),
null,
null);
coord.getCheckpointStore().addCheckpoint(checkpoint);
coord.restoreLatestCheckpointedState(tasks, true, false);
coord.restoreLatestCheckpointedState(tasks, true, true);
// --- (3) JobVertex missing for task state that is part of the checkpoint ---
JobVertexID newJobVertexID = new JobVertexID();
OperatorID newOperatorID = OperatorID.fromJobVertexID(newJobVertexID);
// There is no task for this
{
OperatorState taskState = new OperatorState(newOperatorID, 1, 1);
taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null));
checkpointTaskStates.put(newOperatorID, taskState);
}
checkpoint = new CompletedCheckpoint(
new JobID(),
1,
2,
3,
new HashMap<>(checkpointTaskStates),
Collections.<MasterState>emptyList(),
CheckpointProperties.forStandardCheckpoint(),
null,
null);
coord.getCheckpointStore().addCheckpoint(checkpoint);
// (i) Allow non restored state (should succeed)
coord.restoreLatestCheckpointedState(tasks, true, true);
// (ii) Don't allow non restored state (should fail)
try {
coord.restoreLatestCheckpointedState(tasks, true, false);
fail("Did not throw the expected Exception.");
} catch (IllegalStateException ignored) {
}
}
// ------------------------------------------------------------------------
private Execution mockExecution() {
return mockExecution(ExecutionState.RUNNING);
}
private Execution mockExecution(ExecutionState state) {
Execution mock = mock(Execution.class);
when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
when(mock.getState()).thenReturn(state);
return mock;
}
private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
ExecutionVertex mock = mock(ExecutionVertex.class);
when(mock.getJobvertexId()).thenReturn(vertexId);
when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
when(mock.getMaxParallelism()).thenReturn(parallelism);
return mock;
}
private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
when(vertex.getParallelism()).thenReturn(vertices.length);
when(vertex.getMaxParallelism()).thenReturn(vertices.length);
when(vertex.getJobVertexId()).thenReturn(id);
when(vertex.getTaskVertices()).thenReturn(vertices);
when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
for (ExecutionVertex v : vertices) {
when(v.getJobVertex()).thenReturn(vertex);
}
return vertex;
}
}