/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.test.classloading.jar; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.restartstrategy.RestartStrategies; import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; import org.apache.flink.test.util.SuccessException; import org.apache.flink.util.Collector; import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.concurrent.ThreadLocalRandom; public class CheckpointingCustomKvStateProgram { public static void main(String[] args) throws Exception { final String checkpointPath = args[0]; final String outputPath = args[1]; final int parallelism = 1; StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); env.getConfig().disableSysoutLogging(); env.enableCheckpointing(100); env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 1000)); env.setStateBackend(new FsStateBackend(checkpointPath)); DataStream<Integer> source = env.addSource(new InfiniteIntegerSource()); source .map(new MapFunction<Integer, Tuple2<Integer, Integer>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Integer, Integer> map(Integer value) throws Exception { return new Tuple2<>(ThreadLocalRandom.current().nextInt(parallelism), value); } }) .keyBy(new KeySelector<Tuple2<Integer,Integer>, Integer>() { private static final long serialVersionUID = 1L; @Override public Integer getKey(Tuple2<Integer, Integer> value) throws Exception { return value.f0; } }).flatMap(new ReducingStateFlatMap()).writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE); env.execute(); } private static class InfiniteIntegerSource implements ParallelSourceFunction<Integer>, ListCheckpointed<Integer> { private static final long serialVersionUID = -7517574288730066280L; private volatile boolean running = true; @Override public void run(SourceContext<Integer> ctx) throws Exception { int counter = 0; while (running) { synchronized (ctx.getCheckpointLock()) { ctx.collect(counter++); } } } @Override public void cancel() { running = false; } @Override public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception { return Collections.singletonList(0); } @Override public void restoreState(List<Integer> state) throws Exception { } } private static class ReducingStateFlatMap extends RichFlatMapFunction<Tuple2<Integer, Integer>, Integer> implements ListCheckpointed<ReducingStateFlatMap>, CheckpointListener { private static final long serialVersionUID = -5939722892793950253L; private transient ReducingState<Integer> kvState; private boolean atLeastOneSnapshotComplete = false; private boolean restored = false; @Override public void open(Configuration parameters) throws Exception { ReducingStateDescriptor<Integer> stateDescriptor = new ReducingStateDescriptor<>( "reducing-state", new ReduceSum(), CustomIntSerializer.INSTANCE); this.kvState = getRuntimeContext().getReducingState(stateDescriptor); } @Override public void flatMap(Tuple2<Integer, Integer> value, Collector<Integer> out) throws Exception { kvState.add(value.f1); if(atLeastOneSnapshotComplete) { if (restored) { throw new SuccessException(); } else { throw new RuntimeException("Intended failure, to trigger restore"); } } } @Override public List<ReducingStateFlatMap> snapshotState(long checkpointId, long timestamp) throws Exception { return Collections.singletonList(this); } @Override public void restoreState(List<ReducingStateFlatMap> state) throws Exception { restored = true; atLeastOneSnapshotComplete = true; } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { atLeastOneSnapshotComplete = true; } private static class ReduceSum implements ReduceFunction<Integer> { private static final long serialVersionUID = 1L; @Override public Integer reduce(Integer value1, Integer value2) throws Exception { return value1 + value2; } } } private static final class CustomIntSerializer extends TypeSerializerSingleton<Integer> { private static final long serialVersionUID = 4572452915892737448L; public static final TypeSerializer<Integer> INSTANCE = new CustomIntSerializer(); @Override public boolean isImmutableType() { return true; } @Override public Integer createInstance() { return 0; } @Override public Integer copy(Integer from) { return from; } @Override public Integer copy(Integer from, Integer reuse) { return from; } @Override public int getLength() { return 4; } @Override public void serialize(Integer record, DataOutputView target) throws IOException { target.writeInt(record.intValue()); } @Override public Integer deserialize(DataInputView source) throws IOException { return Integer.valueOf(source.readInt()); } @Override public Integer deserialize(Integer reuse, DataInputView source) throws IOException { return Integer.valueOf(source.readInt()); } @Override public void copy(DataInputView source, DataOutputView target) throws IOException { target.writeInt(source.readInt()); } @Override public boolean canEqual(Object obj) { return obj instanceof CustomIntSerializer; } } }