/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing;

import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.CheckpointListener;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
import org.apache.flink.util.Collector;

import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Integration test for the {@link CheckpointListener} interface. The test ensures that
 * {@link CheckpointListener#notifyCheckpointComplete(long)} is called for completed
 * checkpoints, that it is called at most once for any checkpoint id, and that it is not
 * called for a deliberately failed checkpoint.
 *
 * <p>The topology tested here includes a number of {@link OneInputStreamOperator}s and a
 * {@link TwoInputStreamOperator}.
 *
 * <p>Note that because the checks happen at the task level, there is no way to verify
 * that {@link CheckpointListener#notifyCheckpointComplete(long)} is called for every
 * successfully completed checkpoint.
 */
@SuppressWarnings("serial")
public class StreamCheckpointNotifierITCase extends StreamingMultipleProgramsTestBase {

    private static final Logger LOG = LoggerFactory.getLogger(StreamCheckpointNotifierITCase.class);

    private static final int PARALLELISM = 4;

    /**
     * Runs the following program:
     * <pre>
     *     [ (source)->(filter) ] -> [ (co-map) ] -> [ (map) ] -> [ (groupBy/reduce)->(sink) ]
     * </pre>
     */
    @Test
    public void testProgram() {
        try {
            final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            assertEquals("test setup broken", PARALLELISM, env.getParallelism());

            env.enableCheckpointing(500);
            env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 0L));

            final int numElements = 10000;
            final int numTaskTotal = PARALLELISM * 5;

            DataStream<Long> stream = env.addSource(new GeneratingSourceFunction(numElements, numTaskTotal));

            stream
                    // -------------- first vertex, chained to the src ----------------
                    .filter(new LongRichFilterFunction())

                    // -------------- second vertex, applying the co-map ----------------
                    .connect(stream).flatMap(new LeftIdentityCoRichFlatMapFunction())

                    // -------------- third vertex - the stateful one that also fails ----------------
                    .map(new IdentityMapFunction())
                    .startNewChain()

                    // -------------- fourth vertex - reducer and the sink ----------------
                    .keyBy(0)
                    .reduce(new OnceFailingReducer(numElements))
                    .addSink(new DiscardingSink<Tuple1<Long>>());

            env.execute();

            final long failureCheckpointID = OnceFailingReducer.failureCheckpointID;
            assertNotEquals(0L, failureCheckpointID);

            List<List<Long>[]> allLists = Arrays.asList(
                    GeneratingSourceFunction.completedCheckpoints,
                    LongRichFilterFunction.completedCheckpoints,
                    LeftIdentityCoRichFlatMapFunction.completedCheckpoints,
                    IdentityMapFunction.completedCheckpoints,
                    OnceFailingReducer.completedCheckpoints
            );

            for (List<Long>[] parallelNotifications : allLists) {
                for (List<Long> notifications : parallelNotifications) {

                    assertTrue("No checkpoint notification was received.",
                            notifications.size() > 0);

                    assertFalse("Failure checkpoint was marked as completed.",
                            notifications.contains(failureCheckpointID));

                    assertFalse("No checkpoint received after failure.",
                            notifications.get(notifications.size() - 1) == failureCheckpointID);

                    assertTrue("Checkpoint notification was received multiple times",
                            notifications.size() == new HashSet<Long>(notifications).size());
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            fail(e.getMessage());
        }
    }

    static List<Long>[] createCheckpointLists(int parallelism) {
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<Long>[] lists = new List[parallelism];
        for (int i = 0; i < parallelism; i++) {
            lists[i] = new ArrayList<>();
        }
        return lists;
    }

    // --------------------------------------------------------------------------------------------
    //  Custom Functions
    // --------------------------------------------------------------------------------------------
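
    // Note: all functions below implement CheckpointListener in the same way. Each
    // notifyCheckpointComplete() call appends the checkpoint id to a static, per-subtask
    // list (completedCheckpoints[subtaskIndex]), and the first notification that arrives
    // after the planned failure increments GeneratingSourceFunction.numPostFailureNotifications,
    // which the source polls before it allows itself to shut down. That way the job only
    // finishes once every one of the numTaskTotal tasks has observed at least one completed
    // checkpoint after the recovery.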

    /**
     * Generates a sequence of Long values (each emitted value is the current index modulo 10).
     * As an implementation of the {@link CheckpointListener} interface it records every
     * completed checkpoint id it sees in a static list.
     */
    private static class GeneratingSourceFunction extends RichSourceFunction<Long>
            implements ParallelSourceFunction<Long>, CheckpointListener, ListCheckpointed<Integer> {

        static final List<Long>[] completedCheckpoints = createCheckpointLists(PARALLELISM);

        static AtomicLong numPostFailureNotifications = new AtomicLong();

        // operator behaviour
        private final long numElements;
        private final int notificationsToWaitFor;

        private int index;
        private int step;

        private volatile boolean notificationAlready;

        private volatile boolean isRunning = true;

        GeneratingSourceFunction(long numElements, int notificationsToWaitFor) {
            this.numElements = numElements;
            this.notificationsToWaitFor = notificationsToWaitFor;
        }

        @Override
        public void open(Configuration parameters) throws IOException {
            step = getRuntimeContext().getNumberOfParallelSubtasks();

            // if the index has been restored, it is not 0 any more
            if (index == 0) {
                index = getRuntimeContext().getIndexOfThisSubtask();
            }
        }

        @Override
        public void run(SourceContext<Long> ctx) throws Exception {
            final Object lockingObject = ctx.getCheckpointLock();

            while (isRunning && index < numElements) {
                long result = index % 10;

                synchronized (lockingObject) {
                    index += step;
                    ctx.collect(result);
                }
            }

            // if the program goes fast and no notifications come through, we
            // wait until all tasks had a chance to see a notification
            while (isRunning && numPostFailureNotifications.get() < notificationsToWaitFor) {
                Thread.sleep(50);
            }
        }

        @Override
        public void cancel() {
            isRunning = false;
        }

        @Override
        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(this.index);
        }

        @Override
        public void restoreState(List<Integer> state) throws Exception {
            if (state.isEmpty() || state.size() > 1) {
                throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
            }
            this.index = state.get(0);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            completedCheckpoints[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }
    }
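
    // The source advances its index and emits the corresponding record while holding the
    // checkpoint lock, so a checkpoint can never capture an index update without the
    // matching emission (or vice versa). This is what keeps the restored index consistent
    // with the records that were already sent downstream before a recovery.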

    /**
     * Identity transform on Long values, wrapping the output in a tuple. As an implementation
     * of the {@link CheckpointListener} interface it records every completed checkpoint id
     * it sees in a static list.
     */
    private static class IdentityMapFunction extends RichMapFunction<Long, Tuple1<Long>>
            implements CheckpointListener {

        static final List<Long>[] completedCheckpoints = createCheckpointLists(PARALLELISM);

        private volatile boolean notificationAlready;

        @Override
        public Tuple1<Long> map(Long value) throws Exception {
            return Tuple1.of(value);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            completedCheckpoints[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }
    }

    /**
     * Filter on Long values that lets all values through (the source emits only values
     * between 0 and 9, well below the threshold of 100). As an implementation of the
     * {@link CheckpointListener} interface it records every completed checkpoint id
     * it sees in a static list.
     */
    private static class LongRichFilterFunction extends RichFilterFunction<Long>
            implements CheckpointListener {

        static final List<Long>[] completedCheckpoints = createCheckpointLists(PARALLELISM);

        private volatile boolean notificationAlready;

        @Override
        public boolean filter(Long value) {
            return value < 100;
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            completedCheckpoints[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }
    }

    /**
     * CoFlatMap on Long values acting as the identity transform on the left input while
     * ignoring the right input. As an implementation of the {@link CheckpointListener}
     * interface it records every completed checkpoint id it sees in a static list.
     */
    private static class LeftIdentityCoRichFlatMapFunction extends RichCoFlatMapFunction<Long, Long, Long>
            implements CheckpointListener {

        static final List<Long>[] completedCheckpoints = createCheckpointLists(PARALLELISM);

        private volatile boolean notificationAlready;

        @Override
        public void flatMap1(Long value, Collector<Long> out) {
            out.collect(value);
        }

        @Override
        public void flatMap2(Long value, Collector<Long> out) {
            // we ignore the values from the second input
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            completedCheckpoints[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }
    }
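
    // How the deliberate failure works: OnceFailingReducer (below) throws from inside
    // snapshotState(), so the checkpoint with id failureCheckpointID is never completed,
    // and the job restarts from the last successful checkpoint. Because that checkpoint
    // never completes, no task should ever see its id in notifyCheckpointComplete(),
    // which is exactly what testProgram() asserts.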

    /**
     * Reducer that fails exactly once: subtask 0 throws an exception during the first
     * checkpoint snapshot taken after it has reduced at least half of its share of the
     * records, and remembers the id of the checkpoint in which it failed.
     */
    private static class OnceFailingReducer extends RichReduceFunction<Tuple1<Long>>
            implements ListCheckpointed<Long>, CheckpointListener {

        static volatile boolean hasFailed = false;
        static volatile long failureCheckpointID;

        static final List<Long>[] completedCheckpoints = createCheckpointLists(PARALLELISM);

        private final long failurePos;
        private volatile long count;

        private volatile boolean notificationAlready;

        OnceFailingReducer(long numElements) {
            this.failurePos = (long) (0.5 * numElements / PARALLELISM);
        }

        @Override
        public Tuple1<Long> reduce(Tuple1<Long> value1, Tuple1<Long> value2) {
            count++;
            if (count >= failurePos && getRuntimeContext().getIndexOfThisSubtask() == 0) {
                LOG.info(">>>>>>>>>>>>>>>>> Reached failing position <<<<<<<<<<<<<<<<<<<<<");
            }

            value1.f0 += value2.f0;
            return value1;
        }

        @Override
        public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
            if (!hasFailed && count >= failurePos && getRuntimeContext().getIndexOfThisSubtask() == 0) {
                LOG.info(">>>>>>>>>>>>>>>>> Throwing Exception <<<<<<<<<<<<<<<<<<<<<");
                hasFailed = true;
                failureCheckpointID = checkpointId;
                throw new Exception("Test Failure");
            }
            return Collections.singletonList(this.count);
        }

        @Override
        public void restoreState(List<Long> state) throws Exception {
            if (state.isEmpty() || state.size() > 1) {
                throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
            }
            this.count = state.get(0);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            completedCheckpoints[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }
    }
}