/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing.utils;

import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;

import org.junit.Ignore;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
 * This verifies that we can restore a complete job from a Flink 1.1 savepoint.
 *
 * <p>The test pipeline contains both "Checkpointed" state and keyed user state.
 */
public class StatefulJobSavepointFrom11MigrationITCase extends SavepointMigrationTestBase {
	private static final int NUM_SOURCE_ELEMENTS = 4;
	private static final String EXPECTED_ELEMENTS_ACCUMULATOR = "NUM_EXPECTED_ELEMENTS";
	private static final String SUCCESSFUL_CHECK_ACCUMULATOR = "SUCCESSFUL_CHECKS";
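	// Topology shared by all tests below (uid()s in parentheses):
	//
	//   LegacyCheckpointedSource (LegacyCheckpointedSource)
	//     -> LegacyCheckpointedFlatMap (LegacyCheckpointedFlatMap)
	//     -> keyBy -> LegacyCheckpointedFlatMapWithKeyedState (LegacyCheckpointedFlatMapWithKeyedState)
	//     -> keyBy -> KeyedStateSettingFlatMap (KeyedStateSettingFlatMap)
	//     -> keyBy -> CheckpointedUdfOperator (LegacyCheckpointedOperator)
	//     -> AccumulatorCountingSink
	//
	// The (ignored) "create" tests run the "Legacy*"/"*Setting*" variants on Flink 1.1
	// and take a savepoint; the restore tests run the "*Checking*" variants, wired to
	// the same uids, against that savepoint.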
	/**
	 * This has to be manually executed to create the savepoint on Flink 1.1.
	 */
	@Test
	@Ignore
	public void testCreateSavepointOnFlink11() throws Exception {

		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
		// we only test the memory state backend for now
		env.setStateBackend(new MemoryStateBackend());
		env.enableCheckpointing(500);
		env.setParallelism(4);
		env.setMaxParallelism(4);

		// build the test pipeline
		env
				.addSource(new LegacyCheckpointedSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource")
				.flatMap(new LegacyCheckpointedFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap")
				.keyBy(0)
				.flatMap(new LegacyCheckpointedFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState")
				.keyBy(0)
				.flatMap(new KeyedStateSettingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap")
				.keyBy(0)
				.transform(
						"custom_operator",
						new TypeHint<Tuple2<Long, Long>>() {}.getTypeInfo(),
						new CheckpointedUdfOperator(new LegacyCheckpointedFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator")
				.addSink(new AccumulatorCountingSink<Tuple2<Long, Long>>(EXPECTED_ELEMENTS_ACCUMULATOR));

		executeAndSavepoint(
				env,
				"src/test/resources/stateful-udf-migration-itcase-flink1.1-savepoint",
				new Tuple2<>(EXPECTED_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS));
	}

	/**
	 * This has to be manually executed to create the savepoint on Flink 1.1.
	 */
	@Test
	@Ignore
	public void testCreateSavepointOnFlink11WithRocksDB() throws Exception {

		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
		RocksDBStateBackend rocksBackend =
				new RocksDBStateBackend(new MemoryStateBackend());
		// rocksBackend.enableFullyAsyncSnapshots();
		env.setStateBackend(rocksBackend);
		env.enableCheckpointing(500);
		env.setParallelism(4);
		env.setMaxParallelism(4);

		// build the test pipeline
		env
				.addSource(new LegacyCheckpointedSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource")
				.flatMap(new LegacyCheckpointedFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap")
				.keyBy(0)
				.flatMap(new LegacyCheckpointedFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState")
				.keyBy(0)
				.flatMap(new KeyedStateSettingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap")
				.keyBy(0)
				.transform(
						"custom_operator",
						new TypeHint<Tuple2<Long, Long>>() {}.getTypeInfo(),
						new CheckpointedUdfOperator(new LegacyCheckpointedFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator")
				.addSink(new AccumulatorCountingSink<Tuple2<Long, Long>>(EXPECTED_ELEMENTS_ACCUMULATOR));

		executeAndSavepoint(
				env,
				"src/test/resources/stateful-udf-migration-itcase-flink1.1-rocksdb-savepoint",
				new Tuple2<>(EXPECTED_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS));
	}
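	// Every "*Checking*" variant below asserts the state restored from the savepoint
	// and bumps the SUCCESSFUL_CHECKS accumulator once per successful check: once in
	// the (non-parallel) source's run(), and once per element in each of the five
	// per-element checks (three checking flatMaps in the pipeline, the custom
	// operator itself, and the checking function it wraps). That gives the expected
	// total of 1 + 5 * NUM_SOURCE_ELEMENTS = 21.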
	@Test
	public void testSavepointRestoreFromFlink11() throws Exception {

		final int EXPECTED_SUCCESSFUL_CHECKS = 21;

		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
		// we only test the memory state backend for now
		env.setStateBackend(new MemoryStateBackend());
		env.enableCheckpointing(500);
		env.setParallelism(4);
		env.setMaxParallelism(4);

		// build the test pipeline
		env
				.addSource(new RestoringCheckingSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource")
				.flatMap(new RestoringCheckingFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap")
				.keyBy(0)
				.flatMap(new RestoringCheckingFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState")
				.keyBy(0)
				.flatMap(new KeyedStateCheckingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap")
				.keyBy(0)
				.transform(
						"custom_operator",
						new TypeHint<Tuple2<Long, Long>>() {}.getTypeInfo(),
						new RestoringCheckingUdfOperator(new RestoringCheckingFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator")
				.addSink(new AccumulatorCountingSink<Tuple2<Long, Long>>(EXPECTED_ELEMENTS_ACCUMULATOR));

		restoreAndExecute(
				env,
				getResourceFilename("stateful-udf-migration-itcase-flink1.1-savepoint"),
				new Tuple2<>(SUCCESSFUL_CHECK_ACCUMULATOR, EXPECTED_SUCCESSFUL_CHECKS));
	}

	@Test
	public void testSavepointRestoreFromFlink11FromRocksDB() throws Exception {

		final int EXPECTED_SUCCESSFUL_CHECKS = 21;

		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
		// RocksDB state backend, backed by the memory state backend for checkpoint data
		env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()));
		env.enableCheckpointing(500);
		env.setParallelism(4);
		env.setMaxParallelism(4);

		// build the test pipeline
		env
				.addSource(new RestoringCheckingSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource")
				.flatMap(new RestoringCheckingFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap")
				.keyBy(0)
				.flatMap(new RestoringCheckingFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState")
				.keyBy(0)
				.flatMap(new KeyedStateCheckingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap")
				.keyBy(0)
				.transform(
						"custom_operator",
						new TypeHint<Tuple2<Long, Long>>() {}.getTypeInfo(),
						new RestoringCheckingUdfOperator(new RestoringCheckingFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator")
				.addSink(new AccumulatorCountingSink<Tuple2<Long, Long>>(EXPECTED_ELEMENTS_ACCUMULATOR));

		restoreAndExecute(
				env,
				getResourceFilename("stateful-udf-migration-itcase-flink1.1-rocksdb-savepoint"),
				new Tuple2<>(SUCCESSFUL_CHECK_ACCUMULATOR, EXPECTED_SUCCESSFUL_CHECKS));
	}

	private static class LegacyCheckpointedSource
			implements SourceFunction<Tuple2<Long, Long>>, Checkpointed<String> {

		public static String CHECKPOINTED_STRING = "Here be dragons!";

		private static final long serialVersionUID = 1L;

		private volatile boolean isRunning = true;

		private final int numElements;

		public LegacyCheckpointedSource(int numElements) {
			this.numElements = numElements;
		}

		@Override
		public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception {

			synchronized (ctx.getCheckpointLock()) {
				for (long i = 0; i < numElements; i++) {
					ctx.collect(new Tuple2<>(i, i));
				}
			}
			while (isRunning) {
				Thread.sleep(20);
			}
		}

		@Override
		public void cancel() {
			isRunning = false;
		}

		@Override
		public void restoreState(String state) throws Exception {
			assertEquals(CHECKPOINTED_STRING, state);
		}

		@Override
		public String snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
			return CHECKPOINTED_STRING;
		}
	}
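	// The legacy Checkpointed<T> interface (Flink 1.1) writes state by returning it
	// from snapshotState() and reads it back through restoreState(). The restoring
	// variants below implement only CheckpointedRestoring<T>, which restores such
	// state from an old savepoint without snapshotting it again.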
	private static class RestoringCheckingSource
			extends RichSourceFunction<Tuple2<Long, Long>>
			implements CheckpointedRestoring<String> {

		private static final long serialVersionUID = 1L;

		private volatile boolean isRunning = true;

		private final int numElements;

		private String restoredState;

		public RestoringCheckingSource(int numElements) {
			this.numElements = numElements;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter());
		}

		@Override
		public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception {
			assertEquals(LegacyCheckpointedSource.CHECKPOINTED_STRING, restoredState);
			getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1);

			synchronized (ctx.getCheckpointLock()) {
				for (long i = 0; i < numElements; i++) {
					ctx.collect(new Tuple2<>(i, i));
				}
			}

			while (isRunning) {
				Thread.sleep(20);
			}
		}

		@Override
		public void cancel() {
			isRunning = false;
		}

		@Override
		public void restoreState(String state) throws Exception {
			restoredState = state;
		}
	}

	public static class LegacyCheckpointedFlatMap extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>
			implements Checkpointed<Tuple2<String, Long>> {

		private static final long serialVersionUID = 1L;

		public static Tuple2<String, Long> CHECKPOINTED_TUPLE =
				new Tuple2<>("hello", 42L);

		@Override
		public void flatMap(Tuple2<Long, Long> value, Collector<Tuple2<Long, Long>> out) throws Exception {
			out.collect(value);
		}

		@Override
		public void restoreState(Tuple2<String, Long> state) throws Exception {
		}

		@Override
		public Tuple2<String, Long> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
			return CHECKPOINTED_TUPLE;
		}
	}

	public static class RestoringCheckingFlatMap extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>
			implements CheckpointedRestoring<Tuple2<String, Long>> {

		private static final long serialVersionUID = 1L;

		private transient Tuple2<String, Long> restoredState;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter());
		}

		@Override
		public void flatMap(Tuple2<Long, Long> value, Collector<Tuple2<Long, Long>> out) throws Exception {
			out.collect(value);

			assertEquals(LegacyCheckpointedFlatMap.CHECKPOINTED_TUPLE, restoredState);
			getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1);
		}

		@Override
		public void restoreState(Tuple2<String, Long> state) throws Exception {
			restoredState = state;
		}
	}

	public static class LegacyCheckpointedFlatMapWithKeyedState
			extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>
			implements Checkpointed<Tuple2<String, Long>> {

		private static final long serialVersionUID = 1L;

		public static Tuple2<String, Long> CHECKPOINTED_TUPLE =
				new Tuple2<>("hello", 42L);

		private final ValueStateDescriptor<Long> stateDescriptor =
				new ValueStateDescriptor<Long>("state-name", LongSerializer.INSTANCE);

		@Override
		public void flatMap(Tuple2<Long, Long> value, Collector<Tuple2<Long, Long>> out) throws Exception {
			out.collect(value);

			getRuntimeContext().getState(stateDescriptor).update(value.f1);
		}

		@Override
		public void restoreState(Tuple2<String, Long> state) throws Exception {
		}

		@Override
		public Tuple2<String, Long> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
			return CHECKPOINTED_TUPLE;
		}
	}
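	// Keyed-state round trip: the "Setting" variants above store value.f1 in a
	// ValueState under the descriptor "state-name" while the savepoint is taken;
	// the "Checking" variants below look up the same descriptor after the restore
	// and assert that the stored value matches the incoming element.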
	public static class RestoringCheckingFlatMapWithKeyedState
			extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>
			implements CheckpointedRestoring<Tuple2<String, Long>> {

		private static final long serialVersionUID = 1L;

		private transient Tuple2<String, Long> restoredState;

		private final ValueStateDescriptor<Long> stateDescriptor =
				new ValueStateDescriptor<Long>("state-name", LongSerializer.INSTANCE);

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter());
		}

		@Override
		public void flatMap(Tuple2<Long, Long> value, Collector<Tuple2<Long, Long>> out) throws Exception {
			out.collect(value);

			ValueState<Long> state = getRuntimeContext().getState(stateDescriptor);
			if (state == null) {
				throw new RuntimeException("Missing key value state for " + value);
			}

			assertEquals(value.f1, state.value());
			assertEquals(LegacyCheckpointedFlatMap.CHECKPOINTED_TUPLE, restoredState);
			getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1);
		}

		@Override
		public void restoreState(Tuple2<String, Long> state) throws Exception {
			restoredState = state;
		}
	}

	public static class KeyedStateSettingFlatMap extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {

		private static final long serialVersionUID = 1L;

		private final ValueStateDescriptor<Long> stateDescriptor =
				new ValueStateDescriptor<Long>("state-name", LongSerializer.INSTANCE);

		@Override
		public void flatMap(Tuple2<Long, Long> value, Collector<Tuple2<Long, Long>> out) throws Exception {
			out.collect(value);

			getRuntimeContext().getState(stateDescriptor).update(value.f1);
		}
	}

	public static class KeyedStateCheckingFlatMap extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {

		private static final long serialVersionUID = 1L;

		private final ValueStateDescriptor<Long> stateDescriptor =
				new ValueStateDescriptor<Long>("state-name", LongSerializer.INSTANCE);

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter());
		}

		@Override
		public void flatMap(Tuple2<Long, Long> value, Collector<Tuple2<Long, Long>> out) throws Exception {
			out.collect(value);

			ValueState<Long> state = getRuntimeContext().getState(stateDescriptor);
			if (state == null) {
				throw new RuntimeException("Missing key value state for " + value);
			}

			assertEquals(value.f1, state.value());
			getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1);
		}
	}

	public static class CheckpointedUdfOperator
			extends AbstractUdfStreamOperator<Tuple2<Long, Long>, FlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>>
			implements OneInputStreamOperator<Tuple2<Long, Long>, Tuple2<Long, Long>> {
		private static final long serialVersionUID = 1L;

		private static final String CHECKPOINTED_STRING = "Oh my, that's nice!";

		public CheckpointedUdfOperator(FlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> userFunction) {
			super(userFunction);
		}

		@Override
		public void processElement(StreamRecord<Tuple2<Long, Long>> element) throws Exception {
			output.collect(element);
		}

		@Override
		public void processWatermark(Watermark mark) throws Exception {
			output.emitWatermark(mark);
		}

		// Flink 1.1
		// @Override
		// public StreamTaskState snapshotOperatorState(
		// 		long checkpointId, long timestamp) throws Exception {
		// 	StreamTaskState result = super.snapshotOperatorState(checkpointId, timestamp);
		//
		// 	AbstractStateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(
		// 			checkpointId,
		// 			timestamp);
		//
		// 	out.writeUTF(CHECKPOINTED_STRING);
		//
		// 	result.setOperatorState(out.closeAndGetHandle());
		//
		// 	return result;
		// }
	}
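	// The commented-out snapshotOperatorState() above only compiles against the
	// Flink 1.1 API; it documents how the raw operator state in the savepoint was
	// originally written (a single writeUTF of CHECKPOINTED_STRING). The operator
	// below reads that string back in restoreState() and verifies it per element.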
	public static class RestoringCheckingUdfOperator
			extends AbstractUdfStreamOperator<Tuple2<Long, Long>, FlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>>
			implements OneInputStreamOperator<Tuple2<Long, Long>, Tuple2<Long, Long>> {
		private static final long serialVersionUID = 1L;

		private String restoredState;

		public RestoringCheckingUdfOperator(FlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> userFunction) {
			super(userFunction);
		}

		@Override
		public void open() throws Exception {
			super.open();
		}

		@Override
		public void processElement(StreamRecord<Tuple2<Long, Long>> element) throws Exception {
			userFunction.flatMap(element.getValue(), new TimestampedCollector<>(output));

			assertEquals(CheckpointedUdfOperator.CHECKPOINTED_STRING, restoredState);
			getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1);
		}

		@Override
		public void processWatermark(Watermark mark) throws Exception {
			output.emitWatermark(mark);
		}

		@Override
		public void restoreState(FSDataInputStream in) throws Exception {
			super.restoreState(in);

			DataInputViewStreamWrapper streamWrapper = new DataInputViewStreamWrapper(in);

			restoredState = streamWrapper.readUTF();
		}
	}

	public static class AccumulatorCountingSink<T> extends RichSinkFunction<T> {
		private static final long serialVersionUID = 1L;

		private final String accumulatorName;

		int count = 0;

		public AccumulatorCountingSink(String accumulatorName) {
			this.accumulatorName = accumulatorName;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			getRuntimeContext().addAccumulator(accumulatorName, new IntCounter());
		}

		@Override
		public void invoke(T value) throws Exception {
			count++;
			getRuntimeContext().getAccumulator(accumulatorName).add(1);
		}
	}
}