/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.checkpoint;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.io.SimpleVersionedSerializer;
import org.apache.flink.runtime.concurrent.Executors;
import org.apache.flink.runtime.concurrent.Future;
import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isNull;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Tests for the user-defined hooks that the checkpoint coordinator can call.
*/
public class CheckpointCoordinatorMasterHooksTest {
// ------------------------------------------------------------------------
// hook registration
// ------------------------------------------------------------------------
/**
* This method tests that hooks with the same identifier are not registered
* multiple times.
*/
@Test
public void testDeduplicateOnRegister() {
final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID());
MasterTriggerRestoreHook<?> hook1 = mock(MasterTriggerRestoreHook.class);
when(hook1.getIdentifier()).thenReturn("test id");
MasterTriggerRestoreHook<?> hook2 = mock(MasterTriggerRestoreHook.class);
when(hook2.getIdentifier()).thenReturn("test id");
MasterTriggerRestoreHook<?> hook3 = mock(MasterTriggerRestoreHook.class);
when(hook3.getIdentifier()).thenReturn("anotherId");
assertTrue(cc.addMasterHook(hook1));
assertFalse(cc.addMasterHook(hook2));
assertTrue(cc.addMasterHook(hook3));
}
/**
* Test that validates correct exceptions when supplying hooks with invalid IDs.
*/
@Test
public void testNullOrInvalidId() {
final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID());
try {
cc.addMasterHook(null);
fail("expected an exception");
} catch (NullPointerException ignored) {}
try {
cc.addMasterHook(mock(MasterTriggerRestoreHook.class));
fail("expected an exception");
} catch (IllegalArgumentException ignored) {}
try {
MasterTriggerRestoreHook<?> hook = mock(MasterTriggerRestoreHook.class);
when(hook.getIdentifier()).thenReturn(" ");
cc.addMasterHook(hook);
fail("expected an exception");
} catch (IllegalArgumentException ignored) {}
}
// ------------------------------------------------------------------------
// trigger / restore behavior
// ------------------------------------------------------------------------
@Test
public void testHooksAreCalledOnTrigger() throws Exception {
final String id1 = "id1";
final String id2 = "id2";
final String state1 = "the-test-string-state";
final byte[] state1serialized = new StringSerializer().serialize(state1);
final long state2 = 987654321L;
final byte[] state2serialized = new LongSerializer().serialize(state2);
final MasterTriggerRestoreHook<String> statefulHook1 = mockGeneric(MasterTriggerRestoreHook.class);
when(statefulHook1.getIdentifier()).thenReturn(id1);
when(statefulHook1.createCheckpointDataSerializer()).thenReturn(new StringSerializer());
when(statefulHook1.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)))
.thenReturn(FlinkCompletableFuture.completed(state1));
final MasterTriggerRestoreHook<Long> statefulHook2 = mockGeneric(MasterTriggerRestoreHook.class);
when(statefulHook2.getIdentifier()).thenReturn(id2);
when(statefulHook2.createCheckpointDataSerializer()).thenReturn(new LongSerializer());
when(statefulHook2.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)))
.thenReturn(FlinkCompletableFuture.completed(state2));
final MasterTriggerRestoreHook<Void> statelessHook = mockGeneric(MasterTriggerRestoreHook.class);
when(statelessHook.getIdentifier()).thenReturn("some-id");
// create the checkpoint coordinator
final JobID jid = new JobID();
final ExecutionAttemptID execId = new ExecutionAttemptID();
final ExecutionVertex ackVertex = mockExecutionVertex(execId);
final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
cc.addMasterHook(statefulHook1);
cc.addMasterHook(statelessHook);
cc.addMasterHook(statefulHook2);
// trigger a checkpoint
assertTrue(cc.triggerCheckpoint(System.currentTimeMillis(), false));
assertEquals(1, cc.getNumberOfPendingCheckpoints());
verify(statefulHook1, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class));
verify(statefulHook2, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class));
verify(statelessHook, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class));
final long checkpointId = cc.getPendingCheckpoints().values().iterator().next().getCheckpointId();
cc.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, execId, checkpointId));
assertEquals(0, cc.getNumberOfPendingCheckpoints());
assertEquals(1, cc.getNumberOfRetainedSuccessfulCheckpoints());
final CompletedCheckpoint chk = cc.getCheckpointStore().getLatestCheckpoint();
final Collection<MasterState> masterStates = chk.getMasterHookStates();
assertEquals(2, masterStates.size());
for (MasterState ms : masterStates) {
if (ms.name().equals(id1)) {
assertArrayEquals(state1serialized, ms.bytes());
assertEquals(StringSerializer.VERSION, ms.version());
}
else if (ms.name().equals(id2)) {
assertArrayEquals(state2serialized, ms.bytes());
assertEquals(LongSerializer.VERSION, ms.version());
}
else {
fail("unrecognized state name: " + ms.name());
}
}
}
@Test
public void testHooksAreCalledOnRestore() throws Exception {
final String id1 = "id1";
final String id2 = "id2";
final String state1 = "the-test-string-state";
final byte[] state1serialized = new StringSerializer().serialize(state1);
final long state2 = 987654321L;
final byte[] state2serialized = new LongSerializer().serialize(state2);
final List<MasterState> masterHookStates = Arrays.asList(
new MasterState(id1, state1serialized, StringSerializer.VERSION),
new MasterState(id2, state2serialized, LongSerializer.VERSION));
final MasterTriggerRestoreHook<String> statefulHook1 = mockGeneric(MasterTriggerRestoreHook.class);
when(statefulHook1.getIdentifier()).thenReturn(id1);
when(statefulHook1.createCheckpointDataSerializer()).thenReturn(new StringSerializer());
when(statefulHook1.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)))
.thenThrow(new Exception("not expected"));
final MasterTriggerRestoreHook<Long> statefulHook2 = mockGeneric(MasterTriggerRestoreHook.class);
when(statefulHook2.getIdentifier()).thenReturn(id2);
when(statefulHook2.createCheckpointDataSerializer()).thenReturn(new LongSerializer());
when(statefulHook2.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)))
.thenThrow(new Exception("not expected"));
final MasterTriggerRestoreHook<Void> statelessHook = mockGeneric(MasterTriggerRestoreHook.class);
when(statelessHook.getIdentifier()).thenReturn("some-id");
final JobID jid = new JobID();
final long checkpointId = 13L;
final CompletedCheckpoint checkpoint = new CompletedCheckpoint(
jid, checkpointId, 123L, 125L,
Collections.<OperatorID, OperatorState>emptyMap(),
masterHookStates,
CheckpointProperties.forStandardCheckpoint(),
null,
null);
final ExecutionAttemptID execId = new ExecutionAttemptID();
final ExecutionVertex ackVertex = mockExecutionVertex(execId);
final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
cc.addMasterHook(statefulHook1);
cc.addMasterHook(statelessHook);
cc.addMasterHook(statefulHook2);
cc.getCheckpointStore().addCheckpoint(checkpoint);
cc.restoreLatestCheckpointedState(
Collections.<JobVertexID, ExecutionJobVertex>emptyMap(),
true,
false);
verify(statefulHook1, times(1)).restoreCheckpoint(eq(checkpointId), eq(state1));
verify(statefulHook2, times(1)).restoreCheckpoint(eq(checkpointId), eq(state2));
verify(statelessHook, times(1)).restoreCheckpoint(eq(checkpointId), isNull(Void.class));
}
@Test
public void checkUnMatchedStateOnRestore() throws Exception {
final String id1 = "id1";
final String id2 = "id2";
final String state1 = "the-test-string-state";
final byte[] state1serialized = new StringSerializer().serialize(state1);
final long state2 = 987654321L;
final byte[] state2serialized = new LongSerializer().serialize(state2);
final List<MasterState> masterHookStates = Arrays.asList(
new MasterState(id1, state1serialized, StringSerializer.VERSION),
new MasterState(id2, state2serialized, LongSerializer.VERSION));
final MasterTriggerRestoreHook<String> statefulHook = mockGeneric(MasterTriggerRestoreHook.class);
when(statefulHook.getIdentifier()).thenReturn(id1);
when(statefulHook.createCheckpointDataSerializer()).thenReturn(new StringSerializer());
when(statefulHook.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)))
.thenThrow(new Exception("not expected"));
final MasterTriggerRestoreHook<Void> statelessHook = mockGeneric(MasterTriggerRestoreHook.class);
when(statelessHook.getIdentifier()).thenReturn("some-id");
final JobID jid = new JobID();
final long checkpointId = 44L;
final CompletedCheckpoint checkpoint = new CompletedCheckpoint(
jid, checkpointId, 123L, 125L,
Collections.<OperatorID, OperatorState>emptyMap(),
masterHookStates,
CheckpointProperties.forStandardCheckpoint(),
null,
null);
final ExecutionAttemptID execId = new ExecutionAttemptID();
final ExecutionVertex ackVertex = mockExecutionVertex(execId);
final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
cc.addMasterHook(statefulHook);
cc.addMasterHook(statelessHook);
cc.getCheckpointStore().addCheckpoint(checkpoint);
// since we have unmatched state, this should fail
try {
cc.restoreLatestCheckpointedState(
Collections.<JobVertexID, ExecutionJobVertex>emptyMap(),
true,
false);
fail("exception expected");
}
catch (IllegalStateException ignored) {}
// permitting unmatched state should succeed
cc.restoreLatestCheckpointedState(
Collections.<JobVertexID, ExecutionJobVertex>emptyMap(),
true,
true);
verify(statefulHook, times(1)).restoreCheckpoint(eq(checkpointId), eq(state1));
verify(statelessHook, times(1)).restoreCheckpoint(eq(checkpointId), isNull(Void.class));
}
// ------------------------------------------------------------------------
// failure scenarios
// ------------------------------------------------------------------------
/**
* This test makes sure that the checkpoint is already registered by the time
* that the hooks are called
*/
@Test
public void ensureRegisteredAtHookTime() throws Exception {
final String id = "id";
// create the checkpoint coordinator
final JobID jid = new JobID();
final ExecutionAttemptID execId = new ExecutionAttemptID();
final ExecutionVertex ackVertex = mockExecutionVertex(execId);
final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
final MasterTriggerRestoreHook<Void> hook = mockGeneric(MasterTriggerRestoreHook.class);
when(hook.getIdentifier()).thenReturn(id);
when(hook.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))).thenAnswer(
new Answer<Future<Void>>() {
@Override
public Future<Void> answer(InvocationOnMock invocation) throws Throwable {
assertEquals(1, cc.getNumberOfPendingCheckpoints());
long checkpointId = (Long) invocation.getArguments()[0];
assertNotNull(cc.getPendingCheckpoints().get(checkpointId));
return null;
}
}
);
cc.addMasterHook(hook);
// trigger a checkpoint
assertTrue(cc.triggerCheckpoint(System.currentTimeMillis(), false));
}
// ------------------------------------------------------------------------
// failure scenarios
// ------------------------------------------------------------------------
@Test
public void testSerializationFailsOnTrigger() {
}
@Test
public void testHookCallFailsOnTrigger() {
}
@Test
public void testDeserializationFailsOnRestore() {
}
@Test
public void testHookCallFailsOnRestore() {
}
@Test
public void testTypeIncompatibleWithSerializerOnStore() {
}
@Test
public void testTypeIncompatibleWithHookOnRestore() {
}
// ------------------------------------------------------------------------
// utilities
// ------------------------------------------------------------------------
private static CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, ExecutionVertex... ackVertices) {
return new CheckpointCoordinator(
jid,
10000000L,
600000L,
0L,
1,
ExternalizedCheckpointSettings.none(),
new ExecutionVertex[0],
ackVertices,
new ExecutionVertex[0],
new StandaloneCheckpointIDCounter(),
new StandaloneCompletedCheckpointStore(10),
null,
Executors.directExecutor());
}
private static <T> T mockGeneric(Class<?> clazz) {
@SuppressWarnings("unchecked")
Class<T> typedClass = (Class<T>) clazz;
return mock(typedClass);
}
// ------------------------------------------------------------------------
private static final class StringSerializer implements SimpleVersionedSerializer<String> {
static final int VERSION = 77;
@Override
public int getVersion() {
return VERSION;
}
@Override
public byte[] serialize(String checkpointData) throws IOException {
return checkpointData.getBytes(StandardCharsets.UTF_8);
}
@Override
public String deserialize(int version, byte[] serialized) throws IOException {
if (version != VERSION) {
throw new IOException("version mismatch");
}
return new String(serialized, StandardCharsets.UTF_8);
}
}
// ------------------------------------------------------------------------
private static final class LongSerializer implements SimpleVersionedSerializer<Long> {
static final int VERSION = 5;
@Override
public int getVersion() {
return VERSION;
}
@Override
public byte[] serialize(Long checkpointData) throws IOException {
final byte[] bytes = new byte[8];
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).putLong(0, checkpointData);
return bytes;
}
@Override
public Long deserialize(int version, byte[] serialized) throws IOException {
assertEquals(VERSION, version);
assertEquals(8, serialized.length);
return ByteBuffer.wrap(serialized).order(ByteOrder.LITTLE_ENDIAN).getLong(0);
}
}
}