/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.checkpointing;
import io.netty.util.internal.ConcurrentSet;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.CoreOptions;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.messages.JobManagerMessages;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
import org.apache.flink.runtime.testingUtils.TestingCluster;
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import scala.Option;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.Deadline;
import scala.concurrent.duration.FiniteDuration;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
 * Integration tests for rescaling a job from a savepoint, i.e. restoring a savepoint with a
 * different parallelism than the one it was taken with.
 *
 * <p>TODO: parameterize to test all different state backends!
 */
public class RescalingITCase extends TestLogger {
private static final int numTaskManagers = 2;
private static final int slotsPerTaskManager = 2;
private static final int numSlots = numTaskManagers * slotsPerTaskManager;
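// The checkpointing variants used by the test sources: non-partitioned state,
// CheckpointedFunction with plain or union (broadcast) operator list state, and ListCheckpointed.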
enum OperatorCheckpointMethod {
NON_PARTITIONED, CHECKPOINTED_FUNCTION, CHECKPOINTED_FUNCTION_BROADCAST, LIST_CHECKPOINTED
}
private static TestingCluster cluster;
@ClassRule
public static TemporaryFolder temporaryFolder = new TemporaryFolder();
@BeforeClass
public static void setup() throws Exception {
Configuration config = new Configuration();
config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTaskManagers);
config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, slotsPerTaskManager);
final File checkpointDir = temporaryFolder.newFolder();
final File savepointDir = temporaryFolder.newFolder();
config.setString(CoreOptions.STATE_BACKEND, "filesystem");
config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY, checkpointDir.toURI().toString());
config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY, savepointDir.toURI().toString());
cluster = new TestingCluster(config);
cluster.start();
}
@AfterClass
public static void teardown() {
if (cluster != null) {
cluster.shutdown();
cluster.awaitTermination();
}
}
@Test
public void testSavepointRescalingInKeyedState() throws Exception {
testSavepointRescalingKeyedState(false, false);
}
@Test
public void testSavepointRescalingOutKeyedState() throws Exception {
testSavepointRescalingKeyedState(true, false);
}
@Test
public void testSavepointRescalingInKeyedStateDerivedMaxParallelism() throws Exception {
testSavepointRescalingKeyedState(false, true);
}
@Test
public void testSavepointRescalingOutKeyedStateDerivedMaxParallelism() throws Exception {
testSavepointRescalingKeyedState(true, true);
}
/**
* Tests that a job with purely keyed state can be restarted from a savepoint
* with a different parallelism.
*/
public void testSavepointRescalingKeyedState(boolean scaleOut, boolean deriveMaxParallelism) throws Exception {
final int numberKeys = 42;
final int numberElements = 1000;
final int numberElements2 = 500;
final int parallelism = scaleOut ? numSlots / 2 : numSlots;
final int parallelism2 = scaleOut ? numSlots : numSlots / 2;
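// 13 key groups: not divisible by either parallelism, so rescaling redistributes key groups unevenly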
final int maxParallelism = 13;
FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
Deadline deadline = timeout.fromNow();
ActorGateway jobManager = null;
JobID jobID = null;
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
JobGraph jobGraph = createJobGraphWithKeyedState(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
jobID = jobGraph.getJobID();
cluster.submitJobDetached(jobGraph);
// wait until the sources have emitted numberElements per key and a checkpoint has completed
SubtaskIndexFlatMapper.workCompletedLatch.await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
// verify the current state
Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
}
assertEquals(expectedResult, actualResult);
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
Await.ready(jobRemovedFuture, deadline.timeLeft());
jobID = null;
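// VALUE_NOT_SET leaves the max parallelism unset, so the restored job has to derive it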
int restoreMaxParallelism = deriveMaxParallelism ? ExecutionJobVertex.VALUE_NOT_SET : maxParallelism;
JobGraph scaledJobGraph = createJobGraphWithKeyedState(parallelism2, restoreMaxParallelism, numberKeys, numberElements2, true, 100);
scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
jobID = scaledJobGraph.getJobID();
cluster.submitJobAndWait(scaledJobGraph, false);
jobID = null;
Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult2.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism2, keyGroupIndex), key * (numberElements + numberElements2)));
}
assertEquals(expectedResult2, actualResult2);
} finally {
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
// clear any left overs from a possibly failed job
if (jobID != null && jobManager != null) {
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
try {
Await.ready(jobRemovedFuture, timeout);
} catch (TimeoutException | InterruptedException ie) {
fail("Failed while cleaning up the cluster.");
}
}
}
}
/**
* Tests that a job cannot be restarted from a savepoint with a different parallelism if the
* rescaled operator has non-partitioned state.
 */
@Test
public void testSavepointRescalingNonPartitionedStateCausesException() throws Exception {
final int parallelism = numSlots / 2;
final int parallelism2 = numSlots;
final int maxParallelism = 13;
FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
Deadline deadline = timeout.fromNow();
JobID jobID = null;
ActorGateway jobManager = null;
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
jobID = jobGraph.getJobID();
cluster.submitJobDetached(jobGraph);
Object savepointResponse = null;
// wait until the operator is started
StateSourceBase.workStartedLatch.await();
Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
FiniteDuration waitingTime = new FiniteDuration(10, TimeUnit.SECONDS);
savepointResponse = Await.result(savepointPathFuture, waitingTime);
assertTrue(String.valueOf(savepointResponse), savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
Await.ready(jobRemovedFuture, deadline.timeLeft());
// job successfully removed
jobID = null;
JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
jobID = scaledJobGraph.getJobID();
cluster.submitJobAndWait(scaledJobGraph, false);
jobID = null;
} catch (JobExecutionException exception) {
if (exception.getCause() instanceof IllegalStateException) {
// we expect an IllegalStateException wrapped
// in a JobExecutionException, because the job containing non-partitioned state
// is being rescaled
} else {
throw exception;
}
} finally {
// clear any left overs from a possibly failed job
if (jobID != null && jobManager != null) {
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
try {
Await.ready(jobRemovedFuture, timeout);
} catch (TimeoutException | InterruptedException ie) {
fail("Failed while cleaning up the cluster.");
}
}
}
}
/**
 * Tests that a job with non-partitioned state can be restarted from a savepoint with a
 * different parallelism if the operators with non-partitioned state are not rescaled.
 */
@Test
public void testSavepointRescalingWithKeyedAndNonPartitionedState() throws Exception {
int numberKeys = 42;
int numberElements = 1000;
int numberElements2 = 500;
int parallelism = numSlots / 2;
int parallelism2 = numSlots;
int maxParallelism = 13;
FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
Deadline deadline = timeout.fromNow();
ActorGateway jobManager = null;
JobID jobID = null;
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
JobGraph jobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
parallelism,
maxParallelism,
parallelism,
numberKeys,
numberElements,
false,
100);
jobID = jobGraph.getJobID();
cluster.submitJobDetached(jobGraph);
// wait until the sources have emitted numberElements per key and a checkpoint has completed
SubtaskIndexFlatMapper.workCompletedLatch.await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
// verify the current state
Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
}
assertEquals(expectedResult, actualResult);
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
Await.ready(jobRemovedFuture, deadline.timeLeft());
jobID = null;
JobGraph scaledJobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
parallelism2,
maxParallelism,
parallelism,
numberKeys,
numberElements + numberElements2,
true,
100);
scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
jobID = scaledJobGraph.getJobID();
cluster.submitJobAndWait(scaledJobGraph, false);
jobID = null;
Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult2.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism2, keyGroupIndex), key * (numberElements + numberElements2)));
}
assertEquals(expectedResult2, actualResult2);
} finally {
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
// clear any left overs from a possibly failed job
if (jobID != null && jobManager != null) {
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
try {
Await.ready(jobRemovedFuture, timeout);
} catch (TimeoutException | InterruptedException ie) {
fail("Failed while cleaning up the cluster.");
}
}
}
}
@Test
public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
}
@Test
public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
}
@Test
public void testSavepointRescalingInBroadcastOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
}
@Test
public void testSavepointRescalingOutBroadcastOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
}
@Test
public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
}
@Test
public void testSavepointRescalingOutPartitionedOperatorStateList() throws Exception {
testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.LIST_CHECKPOINTED);
}
/**
 * Tests rescaling of partitioned operator state. More specifically, we test the mechanism with {@link ListCheckpointed}
* as it subsumes {@link org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
*/
public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
final int parallelism = scaleOut ? numSlots : numSlots / 2;
final int parallelism2 = scaleOut ? numSlots / 2 : numSlots;
final int maxParallelism = 13;
FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
Deadline deadline = timeout.fromNow();
JobID jobID = null;
ActorGateway jobManager = null;
int counterSize = Math.max(parallelism, parallelism2);
if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION ||
checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
} else {
PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE = new int[counterSize];
}
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
jobID = jobGraph.getJobID();
cluster.submitJobDetached(jobGraph);
Object savepointResponse = null;
// wait until the operator is started
StateSourceBase.workStartedLatch.await();
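// keep triggering savepoints until one succeeds; early attempts can fail while the job is still settling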
while (deadline.hasTimeLeft()) {
Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
FiniteDuration waitingTime = new FiniteDuration(10, TimeUnit.SECONDS);
savepointResponse = Await.result(savepointPathFuture, waitingTime);
if (savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess) {
break;
}
System.out.println(savepointResponse);
}
assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
Await.ready(jobRemovedFuture, deadline.timeLeft());
// job successfully removed
jobID = null;
JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, checkpointMethod);
scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
jobID = scaledJobGraph.getJobID();
cluster.submitJobAndWait(scaledJobGraph, false);
int sumExp = 0;
int sumAct = 0;
if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
sumExp += c;
}
for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
sumAct += c;
}
} else if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
sumExp += c;
}
for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
sumAct += c;
}
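// union (broadcast) state restores the complete state on every subtask,
// so the expected sum scales with the new parallelism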
sumExp *= parallelism2;
} else {
for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) {
sumExp += c;
}
for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE) {
sumAct += c;
}
}
assertEquals(sumExp, sumAct);
jobID = null;
} finally {
// clear any left overs from a possibly failed job
if (jobID != null && jobManager != null) {
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
try {
Await.ready(jobRemovedFuture, timeout);
} catch (TimeoutException | InterruptedException ie) {
fail("Failed while cleaning up the cluster.");
}
}
}
}
//------------------------------------------------------------------------------------------------------------------
private static JobGraph createJobGraphWithOperatorState(
int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
env.getConfig().setMaxParallelism(maxParallelism);
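// effectively disables periodic checkpointing; state is only snapshotted when the savepoint is triggered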
env.enableCheckpointing(Long.MAX_VALUE);
env.setRestartStrategy(RestartStrategies.noRestart());
StateSourceBase.workStartedLatch = new CountDownLatch(parallelism);
SourceFunction<Integer> src;
switch (checkpointMethod) {
case CHECKPOINTED_FUNCTION:
src = new PartitionedStateSource(false);
break;
case CHECKPOINTED_FUNCTION_BROADCAST:
src = new PartitionedStateSource(true);
break;
case LIST_CHECKPOINTED:
src = new PartitionedStateSourceListCheckpointed();
break;
case NON_PARTITIONED:
src = new NonPartitionedStateSource();
break;
default:
throw new IllegalArgumentException();
}
DataStream<Integer> input = env.addSource(src);
input.addSink(new DiscardingSink<Integer>());
return env.getStreamGraph().getJobGraph();
}
private static JobGraph createJobGraphWithKeyedState(
int parallelism,
int maxParallelism,
int numberKeys,
int numberElements,
boolean terminateAfterEmission,
int checkpointingInterval) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
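// a non-positive value means "not set"; the max parallelism is then derived on restore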
if (0 < maxParallelism) {
env.getConfig().setMaxParallelism(maxParallelism);
}
env.enableCheckpointing(checkpointingInterval);
env.setRestartStrategy(RestartStrategies.noRestart());
DataStream<Integer> input = env.addSource(new SubtaskIndexSource(
numberKeys,
numberElements,
terminateAfterEmission))
.keyBy(new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = -7952298871120320940L;
@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
});
SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
DataStream<Tuple2<Integer, Integer>> result = input.flatMap(new SubtaskIndexFlatMapper(numberElements));
result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
return env.getStreamGraph().getJobGraph();
}
private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
int parallelism,
int maxParallelism,
int fixedParallelism,
int numberKeys,
int numberElements,
boolean terminateAfterEmission,
int checkpointingInterval) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
env.getConfig().setMaxParallelism(maxParallelism);
env.enableCheckpointing(checkpointingInterval);
env.setRestartStrategy(RestartStrategies.noRestart());
DataStream<Integer> input = env.addSource(new SubtaskIndexNonPartitionedStateSource(
numberKeys,
numberElements,
terminateAfterEmission))
.setParallelism(fixedParallelism)
.keyBy(new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = -7952298871120320940L;
@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
});
SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
DataStream<Tuple2<Integer, Integer>> result = input.flatMap(new SubtaskIndexFlatMapper(numberElements));
result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
return env.getStreamGraph().getJobGraph();
}
private static class SubtaskIndexSource
extends RichParallelSourceFunction<Integer> {
private static final long serialVersionUID = -400066323594122516L;
private final int numberKeys;
private final int numberElements;
private final boolean terminateAfterEmission;
protected int counter = 0;
private boolean running = true;
SubtaskIndexSource(
int numberKeys,
int numberElements,
boolean terminateAfterEmission) {
this.numberKeys = numberKeys;
this.numberElements = numberElements;
this.terminateAfterEmission = terminateAfterEmission;
}
@Override
public void run(SourceContext<Integer> ctx) throws Exception {
final Object lock = ctx.getCheckpointLock();
final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
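// each subtask emits the keys it owns (subtaskIndex, subtaskIndex + numSubtasks, ...) once per loop iteration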
while (running) {
if (counter < numberElements) {
synchronized (lock) {
for (int value = subtaskIndex;
value < numberKeys;
value += getRuntimeContext().getNumberOfParallelSubtasks()) {
ctx.collect(value);
}
counter++;
}
} else {
if (terminateAfterEmission) {
running = false;
} else {
Thread.sleep(100);
}
}
}
}
@Override
public void cancel() {
running = false;
}
}
private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource implements ListCheckpointed<Integer> {
private static final long serialVersionUID = 8388073059042040203L;
SubtaskIndexNonPartitionedStateSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
super(numberKeys, numberElements, terminateAfterEmission);
}
@Override
public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
return Collections.singletonList(this.counter);
}
@Override
public void restoreState(List<Integer> state) throws Exception {
if (state.size() != 1) {
throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
}
this.counter = state.get(0);
}
}
private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> implements CheckpointedFunction {
private static final long serialVersionUID = 5273172591283191348L;
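// re-armed by the job graph factories with one count per key; a key counts down
// whenever its element count reaches a multiple of numberElements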
private static volatile CountDownLatch workCompletedLatch = new CountDownLatch(1);
private transient ValueState<Integer> counter;
private transient ValueState<Integer> sum;
private final int numberElements;
SubtaskIndexFlatMapper(int numberElements) {
this.numberElements = numberElements;
}
@Override
public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
int count = counter.value() + 1;
counter.update(count);
int s = sum.value() + value;
sum.update(s);
if (count % numberElements == 0) {
out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
workCompletedLatch.countDown();
}
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
// all state is managed keyed state; nothing to snapshot manually
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
counter = context.getKeyedStateStore().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
sum = context.getKeyedStateStore().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
}
}
private static class CollectionSink<IN> implements SinkFunction<IN> {
private static final ConcurrentSet<Object> elements = new ConcurrentSet<>();
private static final long serialVersionUID = -1652452958040267745L;
@SuppressWarnings("unchecked")
public static <IN> Set<IN> getElementsSet() {
return (Set<IN>) elements;
}
public static void clearElementsSet() {
elements.clear();
}
@Override
public void invoke(IN value) throws Exception {
elements.add(value);
}
}
private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
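// signals that every source subtask has started doing work; re-armed with the source parallelism per job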
private static volatile CountDownLatch workStartedLatch = new CountDownLatch(1);
protected volatile int counter = 0;
protected volatile boolean running = true;
@Override
public void run(SourceContext<Integer> ctx) throws Exception {
final Object lock = ctx.getCheckpointLock();
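// emit under the checkpoint lock; signal readiness after 10 elements and stop after 500 to bound the run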
while (running) {
synchronized (lock) {
++counter;
ctx.collect(1);
}
Thread.sleep(2);
if (counter == 10) {
workStartedLatch.countDown();
}
if (counter >= 500) {
break;
}
}
}
@Override
public void cancel() {
running = false;
}
}
private static class NonPartitionedStateSource extends StateSourceBase implements ListCheckpointed<Integer> {
private static final long serialVersionUID = -8108185918123186841L;
@Override
public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
return Collections.singletonList(this.counter);
}
@Override
public void restoreState(List<Integer> state) throws Exception {
if (!state.isEmpty()) {
this.counter = state.get(0);
}
}
}
private static class PartitionedStateSourceListCheckpointed extends StateSourceBase implements ListCheckpointed<Integer> {
private static final long serialVersionUID = -4357864582992546L;
private static final int NUM_PARTITIONS = 7;
private static int[] CHECK_CORRECT_SNAPSHOT;
private static int[] CHECK_CORRECT_RESTORE;
@Override
public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
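// split the counter into NUM_PARTITIONS nearly equal parts; the first 'mod' parts get one extra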
int div = counter / NUM_PARTITIONS;
int mod = counter % NUM_PARTITIONS;
List<Integer> split = new ArrayList<>();
for (int i = 0; i < NUM_PARTITIONS; ++i) {
int partitionValue = div;
if (mod > 0) {
--mod;
++partitionValue;
}
split.add(partitionValue);
}
return split;
}
@Override
public void restoreState(List<Integer> state) throws Exception {
for (Integer v : state) {
counter += v;
}
CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
}
}
private static class PartitionedStateSource extends StateSourceBase implements CheckpointedFunction, CheckpointedRestoring<Integer> {
private static final long serialVersionUID = -359715965103593462L;
private static final int NUM_PARTITIONS = 7;
private ListState<Integer> counterPartitions;
private boolean broadcast;
private static int[] CHECK_CORRECT_SNAPSHOT;
private static int[] CHECK_CORRECT_RESTORE;
public PartitionedStateSource(boolean broadcast) {
this.broadcast = broadcast;
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
counterPartitions.clear();
CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
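// same splitting scheme as the ListCheckpointed variant: NUM_PARTITIONS nearly equal parts,
// with the remainder spread over the first parts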
int div = counter / NUM_PARTITIONS;
int mod = counter % NUM_PARTITIONS;
for (int i = 0; i < NUM_PARTITIONS; ++i) {
int partitionValue = div;
if (mod > 0) {
--mod;
++partitionValue;
}
counterPartitions.add(partitionValue);
}
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
if (broadcast) {
this.counterPartitions = context
.getOperatorStateStore()
.getUnionListState(new ListStateDescriptor<>("counter_partitions", IntSerializer.INSTANCE));
} else {
this.counterPartitions = context
.getOperatorStateStore()
.getListState(new ListStateDescriptor<>("counter_partitions", IntSerializer.INSTANCE));
}
if (context.isRestored()) {
for (int v : counterPartitions.get()) {
counter += v;
}
CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
}
}
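// legacy restore path: invoked when migrating state written by the old Checkpointed interface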
@Override
public void restoreState(Integer state) throws Exception {
counterPartitions.add(state);
}
}
}