/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing;

import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.CheckpointListener;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.util.Collector;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * A simple test that runs a streaming topology with checkpointing enabled.
 *
 * <p>The test triggers a failure after a while and verifies that, after completion, the
 * state defined with either the {@link ValueState} or the {@link ListCheckpointed}
 * interface reflects the "exactly once" semantics.
 *
 * <p>The test throttles the input until at least two checkpoints are completed, to make sure that
 * the recovery does not fall back to "square one" (which would naturally lead to correct
 * results without testing the checkpointing).
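 *
 * <p>Every stateful function below follows the same {@link ListCheckpointed} contract:
 * state is snapshotted as a singleton list and restored from exactly one element. A
 * minimal sketch of that pattern ({@code CountingFunction} and its {@code count} field
 * are hypothetical, for illustration only, not part of this test):
 *
 * <pre>{@code
 * class CountingFunction implements ListCheckpointed<Long> {
 *     private long count;
 *
 *     public List<Long> snapshotState(long checkpointId, long timestamp) {
 *         // hand the complete state to the checkpoint as a one-element list
 *         return Collections.singletonList(count);
 *     }
 *
 *     public void restoreState(List<Long> state) {
 *         // on recovery, exactly one element is expected back
 *         count = state.get(0);
 *     }
 * }
 * }</pre>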
 */
@SuppressWarnings("serial")
public class StateCheckpointedITCase extends StreamFaultToleranceTestBase {

	private static final Logger LOG = LoggerFactory.getLogger(StateCheckpointedITCase.class);

	final long NUM_STRINGS = 10_000_000L;

	/**
	 * Runs the following program:
	 *
	 * <pre>
	 *     [ (source)->(filter) ] -> [ (map) -> (map) ] -> [ (groupBy/reduce)->(sink) ]
	 * </pre>
	 */
	@Override
	public void testProgram(StreamExecutionEnvironment env) {
		assertTrue("Broken test setup", NUM_STRINGS % 40 == 0);

		final long failurePosMin = (long) (0.4 * NUM_STRINGS / PARALLELISM);
		final long failurePosMax = (long) (0.7 * NUM_STRINGS / PARALLELISM);

		// pick the failure position from [failurePosMin, failurePosMax); note that
		// Random.nextLong() % range could be negative and push the position below
		// failurePosMin, so the offset is derived from nextDouble() instead
		final long failurePos = failurePosMin +
				(long) (new Random().nextDouble() * (failurePosMax - failurePosMin));

		env.enableCheckpointing(200);

		DataStream<String> stream = env.addSource(new StringGeneratingSourceFunction(NUM_STRINGS));

		stream
				// first vertex, chained to the source
				// this filter throttles the flow until at least one checkpoint
				// is complete, to make sure this program does not run without any
				// completed checkpoint
				.filter(new StringRichFilterFunction())

				// -------------- second vertex - one-to-one connected ----------------
				.map(new StringPrefixCountRichMapFunction())
				.startNewChain()
				.map(new StatefulCounterFunction())

				// -------------- third vertex - reducer and the sink ----------------
				.keyBy("prefix")
				.flatMap(new OnceFailingAggregator(failurePos))
				.addSink(new ValidatingSink());
	}

	@Override
	public void postSubmit() {
		// the run is inconclusive (but not failed) if the failure happened before the
		// first checkpoint completed, so this only logs a warning instead of asserting
		if (!OnceFailingAggregator.wasCheckpointedBeforeFailure) {
			LOG.warn("Test inconclusive: failure occurred before first checkpoint");
		}

		long filterSum = 0;
		for (long l : StringRichFilterFunction.counts) {
			filterSum += l;
		}

		long mapSum = 0;
		for (long l : StringPrefixCountRichMapFunction.counts) {
			mapSum += l;
		}

		long countSum = 0;
		for (long l : StatefulCounterFunction.counts) {
			countSum += l;
		}

		// verify that we counted exactly right
		assertEquals(NUM_STRINGS, filterSum);
		assertEquals(NUM_STRINGS, mapSum);
		assertEquals(NUM_STRINGS, countSum);

		for (Map<Character, Long> map : ValidatingSink.maps) {
			for (Long count : map.values()) {
				assertEquals(NUM_STRINGS / 40, count.longValue());
			}
		}
	}

	// --------------------------------------------------------------------------------------------
	//  Custom Functions
	// --------------------------------------------------------------------------------------------

	private static class StringGeneratingSourceFunction extends RichParallelSourceFunction<String>
			implements ListCheckpointed<Integer> {

		private final long numElements;

		private int index;

		private volatile boolean isRunning = true;

		StringGeneratingSourceFunction(long numElements) {
			this.numElements = numElements;
		}

		@Override
		public void run(SourceContext<String> ctx) throws Exception {
			final Object lockingObject = ctx.getCheckpointLock();

			final Random rnd = new Random();
			final StringBuilder stringBuilder = new StringBuilder();

			final int step = getRuntimeContext().getNumberOfParallelSubtasks();
			if (index == 0) {
				index = getRuntimeContext().getIndexOfThisSubtask();
			}

			while (isRunning && index < numElements) {
				char first = (char) ((index % 40) + 40);

				stringBuilder.setLength(0);
				stringBuilder.append(first);

				String result = randomString(stringBuilder, rnd);

				// advance the index and emit under the checkpoint lock, so that the
				// emitted record and the snapshotted index stay consistent
				synchronized (lockingObject) {
					index += step;
					ctx.collect(result);
				}
			}
		}

		@Override
		public void cancel() {
			isRunning = false;
		}

		private static String randomString(StringBuilder bld, Random rnd) {
			final int len = rnd.nextInt(10) + 5;

			for (int i = 0; i < len; i++) {
				char next = (char) (rnd.nextInt(20000) + 33);
				bld.append(next);
			}

			return bld.toString();
		}

		@Override
		public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(this.index);
		}

		@Override
		public void restoreState(List<Integer> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.index = state.get(0);
		}
	}
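
	/**
	 * Filter that counts every record it sees. The count is snapshotted through
	 * {@link ListCheckpointed} and published to the static {@code counts} array in
	 * {@code close()}, which is how {@code postSubmit()} later verifies that exactly
	 * {@code NUM_STRINGS} records passed through despite the induced failure.
	 */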
	private static class StringRichFilterFunction extends RichFilterFunction<String>
			implements ListCheckpointed<Long> {

		static final long[] counts = new long[PARALLELISM];

		private long count;

		@Override
		public boolean filter(String value) throws Exception {
			count++;
			return value.length() < 100; // should always be true
		}

		@Override
		public void close() {
			counts[getRuntimeContext().getIndexOfThisSubtask()] = count;
		}

		@Override
		public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(this.count);
		}

		@Override
		public void restoreState(List<Long> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.count = state.get(0);
		}
	}

	private static class StringPrefixCountRichMapFunction extends RichMapFunction<String, PrefixCount>
			implements ListCheckpointed<Long> {

		static final long[] counts = new long[PARALLELISM];

		private long count;

		@Override
		public PrefixCount map(String value) {
			count++;
			return new PrefixCount(value.substring(0, 1), value, 1L);
		}

		@Override
		public void close() {
			counts[getRuntimeContext().getIndexOfThisSubtask()] = count;
		}

		@Override
		public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(this.count);
		}

		@Override
		public void restoreState(List<Long> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.count = state.get(0);
		}
	}

	private static class StatefulCounterFunction extends RichMapFunction<PrefixCount, PrefixCount>
			implements ListCheckpointed<Long> {

		static final long[] counts = new long[PARALLELISM];

		private long count;

		@Override
		public PrefixCount map(PrefixCount value) throws Exception {
			count++;
			return value;
		}

		@Override
		public void close() throws IOException {
			counts[getRuntimeContext().getIndexOfThisSubtask()] = count;
		}

		@Override
		public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(this.count);
		}

		@Override
		public void restoreState(List<Long> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.count = state.get(0);
		}
	}

	private static class OnceFailingAggregator extends RichFlatMapFunction<PrefixCount, PrefixCount>
			implements ListCheckpointed<HashMap<String, PrefixCount>>, CheckpointListener {

		static boolean wasCheckpointedBeforeFailure = false;

		private static volatile boolean hasFailed = false;

		private final HashMap<String, PrefixCount> aggregationMap = new HashMap<String, PrefixCount>();

		private final long failurePos;

		private long count;

		private boolean wasCheckpointed;

		OnceFailingAggregator(long failurePos) {
			this.failurePos = failurePos;
		}

		@Override
		public void open(Configuration parameters) {
			count = 0;
		}
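
		// The failure is induced exactly once, on subtask 1 only, after 'failurePos'
		// records. Whether a checkpoint had completed by then is recorded in
		// 'wasCheckpointedBeforeFailure', so that postSubmit() can tell whether the
		// run actually exercised checkpoint recovery.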
		@Override
		public void flatMap(PrefixCount value, Collector<PrefixCount> out) throws Exception {
			count++;
			if (!hasFailed && count >= failurePos && getRuntimeContext().getIndexOfThisSubtask() == 1) {
				wasCheckpointedBeforeFailure = wasCheckpointed;
				hasFailed = true;
				throw new Exception("Test Failure");
			}

			PrefixCount curr = aggregationMap.get(value.prefix);
			if (curr == null) {
				aggregationMap.put(value.prefix, value);
				out.collect(value);
			}
			else {
				curr.count += value.count;
				out.collect(curr);
			}
		}

		@Override
		public List<HashMap<String, PrefixCount>> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(this.aggregationMap);
		}

		@Override
		public void restoreState(List<HashMap<String, PrefixCount>> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.aggregationMap.putAll(state.get(0));
		}

		@Override
		public void notifyCheckpointComplete(long checkpointId) {
			this.wasCheckpointed = true;
		}
	}

	private static class ValidatingSink extends RichSinkFunction<PrefixCount>
			implements ListCheckpointed<HashMap<Character, Long>> {

		@SuppressWarnings("unchecked")
		private static Map<Character, Long>[] maps = (Map<Character, Long>[]) new Map<?, ?>[PARALLELISM];

		private HashMap<Character, Long> counts = new HashMap<Character, Long>();

		@Override
		public void invoke(PrefixCount value) {
			Character first = value.prefix.charAt(0);
			Long previous = counts.get(first);
			if (previous == null) {
				counts.put(first, value.count);
			}
			else {
				counts.put(first, Math.max(previous, value.count));
			}
		}

		@Override
		public void close() throws Exception {
			maps[getRuntimeContext().getIndexOfThisSubtask()] = counts;
		}

		@Override
		public List<HashMap<Character, Long>> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(this.counts);
		}

		@Override
		public void restoreState(List<HashMap<Character, Long>> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.counts.putAll(state.get(0));
		}
	}
}