/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.testkit.JavaTestKit;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.JobSubmissionResult;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.CoreOptions;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.akka.AkkaUtils;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.TaskInformation;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.messages.JobManagerMessages;
import org.apache.flink.runtime.messages.JobManagerMessages.CancelJob;
import org.apache.flink.runtime.messages.JobManagerMessages.DisposeSavepoint;
import org.apache.flink.runtime.messages.JobManagerMessages.TriggerSavepoint;
import org.apache.flink.runtime.messages.JobManagerMessages.TriggerSavepointSuccess;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.filesystem.FileStateHandle;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
import org.apache.flink.runtime.testingUtils.TestingCluster;
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.RequestSavepoint;
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.ResponseSavepoint;
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunning;
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.ResponseSubmitTaskListener;
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.IterativeStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.Deadline;
import scala.concurrent.duration.FiniteDuration;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.apache.flink.runtime.messages.JobManagerMessages.getDisposeSavepointSuccess;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Integration test for triggering and resuming from savepoints.
 */
@SuppressWarnings("serial")
public class SavepointITCase extends TestLogger {

	private static final Logger LOG = LoggerFactory.getLogger(SavepointITCase.class);

	@Rule
	public TemporaryFolder folder = new TemporaryFolder();
	/**
	 * Triggers a savepoint for a job that uses the FsStateBackend. We expect
	 * that all checkpoint files are written to a new savepoint directory.
	 *
	 * <ol>
	 * <li>Submit job, wait for some progress</li>
	 * <li>Trigger savepoint and verify that savepoint has been created</li>
	 * <li>Shut down the cluster, re-submit the job from the savepoint,
	 * verify that the initial state has been reset, and
	 * all tasks are running again</li>
	 * <li>Cancel job, dispose the savepoint, and verify that everything
	 * has been cleaned up</li>
	 * </ol>
	 */
	@Test
	public void testTriggerSavepointAndResumeWithFileBasedCheckpoints() throws Exception {
		// Config
		final int numTaskManagers = 2;
		final int numSlotsPerTaskManager = 2;
		final int parallelism = numTaskManagers * numSlotsPerTaskManager;
		final Deadline deadline = new FiniteDuration(5, TimeUnit.MINUTES).fromNow();
		final File testRoot = folder.newFolder();

		TestingCluster flink = null;

		try {
			// Create a test actor system
			ActorSystem testActorSystem = AkkaUtils.createDefaultActorSystem();

			// Flink configuration
			final Configuration config = new Configuration();
			config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTaskManagers);
			config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, numSlotsPerTaskManager);

			final File checkpointDir = new File(testRoot, "checkpoints");
			final File savepointRootDir = new File(testRoot, "savepoints");

			if (!checkpointDir.mkdir() || !savepointRootDir.mkdirs()) {
				fail("Test setup failed: failed to create temporary directories.");
			}

			// Use file based checkpoints
			config.setString(CoreOptions.STATE_BACKEND, "filesystem");
			config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY, checkpointDir.toURI().toString());
			config.setString(FsStateBackendFactory.MEMORY_THRESHOLD_CONF_KEY, "0");
			config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY, savepointRootDir.toURI().toString());

			// Start Flink
			flink = new TestingCluster(config);
			flink.start(true);

			// Submit the job
			final JobGraph jobGraph = createJobGraph(parallelism, 0, 1000);
			final JobID jobId = jobGraph.getJobID();

			// Reset the static test job helpers
			StatefulCounter.resetForTest(parallelism);

			// Retrieve the job manager
			ActorGateway jobManager = Await.result(flink.leaderGateway().future(), deadline.timeLeft());

			LOG.info("Submitting job " + jobGraph.getJobID() + " in detached mode.");

			flink.submitJobDetached(jobGraph);

			LOG.info("Waiting for some progress.");

			// wait for the JobManager to be ready
			Future<Object> allRunning = jobManager.ask(new WaitForAllVerticesToBeRunning(jobId), deadline.timeLeft());
			Await.ready(allRunning, deadline.timeLeft());

			// wait for the Tasks to be ready
			StatefulCounter.getProgressLatch().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

			LOG.info("Triggering a savepoint.");
			Future<Object> savepointPathFuture = jobManager.ask(new TriggerSavepoint(jobId, Option.<String>empty()), deadline.timeLeft());
			final String savepointPath = ((TriggerSavepointSuccess) Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
			LOG.info("Retrieved savepoint path: " + savepointPath + ".");

			// Retrieve the savepoint from the testing job manager
			LOG.info("Requesting the savepoint.");
			Future<Object> savepointFuture = jobManager.ask(new RequestSavepoint(savepointPath), deadline.timeLeft());

			SavepointV2 savepoint = (SavepointV2) ((ResponseSavepoint) Await.result(savepointFuture, deadline.timeLeft())).savepoint();
			LOG.info("Retrieved savepoint: " + savepointPath + ".");

			// Shut down the Flink cluster (thereby canceling the job)
			LOG.info("Shutting down Flink cluster.");
			flink.shutdown();
			flink.awaitTermination();
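			// Note: the shutdown cancels the running job, but the savepoint has to
			// survive it. Unlike periodic checkpoints, savepoint files are owned by
			// the user and are only removed by an explicit dispose request.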
			// - Verification START -------------------------------------------

			// Only one savepoint should exist
			File[] files = savepointRootDir.listFiles();

			if (files != null) {
				assertEquals("Savepoint not created in expected directory", 1, files.length);
				assertTrue("Savepoint did not create self-contained directory", files[0].isDirectory());

				File savepointDir = files[0];
				File[] savepointFiles = savepointDir.listFiles();
				assertNotNull(savepointFiles);

				// Expect one metadata file and one checkpoint file per stateful
				// parallel subtask
				String errMsg = "Did not write expected number of savepoint/checkpoint files to directory: "
					+ Arrays.toString(savepointFiles);
				assertEquals(errMsg, 1 + parallelism, savepointFiles.length);
			} else {
				fail("Savepoint not created in expected directory");
			}

			// We currently have the following directory layout: checkpointDir/jobId/chk-ID
			File jobCheckpoints = new File(checkpointDir, jobId.toString());

			if (jobCheckpoints.exists()) {
				files = jobCheckpoints.listFiles();
				assertNotNull("Checkpoint directory empty", files);
				assertEquals("Checkpoints directory not clean: " + Arrays.toString(files), 0, files.length);
			}

			// - Verification END ---------------------------------------------

			// Restart the cluster
			LOG.info("Restarting Flink cluster.");
			flink.start();

			// Retrieve the job manager
			LOG.info("Retrieving JobManager.");
			jobManager = Await.result(flink.leaderGateway().future(), deadline.timeLeft());
			LOG.info("JobManager: " + jobManager + ".");

			// Reset static test helpers
			StatefulCounter.resetForTest(parallelism);

			// Gather all task deployment descriptors
			final Throwable[] error = new Throwable[1];
			final TestingCluster finalFlink = flink;
			final Multimap<JobVertexID, TaskDeploymentDescriptor> tdds = HashMultimap.create();

			new JavaTestKit(testActorSystem) {{

				new Within(deadline.timeLeft()) {
					@Override
					protected void run() {
						try {
							// Register to all submit task messages for job
							for (ActorRef taskManager : finalFlink.getTaskManagersAsJava()) {
								taskManager.tell(new TestingTaskManagerMessages
									.RegisterSubmitTaskListener(jobId), getTestActor());
							}

							// Set the savepoint path
							jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));

							LOG.info("Resubmitting job " + jobGraph.getJobID() + " with " +
								"savepoint path " + savepointPath + " in detached mode.");

							// Submit the job
							finalFlink.submitJobDetached(jobGraph);

							int numTasks = 0;
							for (JobVertex jobVertex : jobGraph.getVertices()) {
								numTasks += jobVertex.getParallelism();
							}

							// Gather the task deployment descriptors
							LOG.info("Gathering " + numTasks + " submitted " +
								"TaskDeploymentDescriptor instances.");

							for (int i = 0; i < numTasks; i++) {
								ResponseSubmitTaskListener resp = (ResponseSubmitTaskListener)
									expectMsgAnyClassOf(getRemainingTime(), ResponseSubmitTaskListener.class);

								TaskDeploymentDescriptor tdd = resp.tdd();
								LOG.info("Received: " + tdd.toString() + ".");

								TaskInformation taskInformation = tdd
									.getSerializedTaskInformation()
									.deserializeValue(getClass().getClassLoader());

								tdds.put(taskInformation.getJobVertexId(), tdd);
							}
						} catch (Throwable t) {
							error[0] = t;
						}
					}
				};
			}};

			ExecutionGraph graph = (ExecutionGraph) ((JobManagerMessages.JobFound)
				Await.result(jobManager.ask(new JobManagerMessages.RequestJob(jobId), deadline.timeLeft()), deadline.timeLeft()))
				.executionGraph();

			// - Verification START -------------------------------------------

			String errMsg = "Error during gathering of TaskDeploymentDescriptors";
			if (error[0] != null) {
				throw new RuntimeException(errMsg, error[0]);
			}
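			// A job vertex runs a chain of operators; record each operator's
			// position within its chain so that the state stored for an operator in
			// the savepoint can be matched against the deployment descriptors below.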
			Map<OperatorID, Tuple2<Integer, ExecutionJobVertex>> operatorToJobVertexMapping = new HashMap<>();
			for (ExecutionJobVertex task : graph.getVerticesTopologically()) {
				List<OperatorID> operatorIDs = task.getOperatorIDs();
				for (int x = 0; x < operatorIDs.size(); x++) {
					operatorToJobVertexMapping.put(operatorIDs.get(x), new Tuple2<>(x, task));
				}
			}

			// Verify that all tasks, which are part of the savepoint
			// have a matching task deployment descriptor.
			for (OperatorState operatorState : savepoint.getOperatorStates()) {
				Tuple2<Integer, ExecutionJobVertex> chainIndexAndJobVertex = operatorToJobVertexMapping.get(operatorState.getOperatorID());
				Collection<TaskDeploymentDescriptor> taskTdds = tdds.get(chainIndexAndJobVertex.f1.getJobVertexId());

				errMsg = "Missing task for savepoint state for operator " + operatorState.getOperatorID() + ".";
				assertTrue(errMsg, taskTdds.size() > 0);

				assertEquals(operatorState.getNumberCollectedStates(), taskTdds.size());

				for (TaskDeploymentDescriptor tdd : taskTdds) {
					OperatorSubtaskState subtaskState = operatorState.getState(tdd.getSubtaskIndex());

					assertNotNull(subtaskState);

					errMsg = "Initial operator state mismatch.";
					assertEquals(errMsg, subtaskState.getLegacyOperatorState(),
						tdd.getTaskStateHandles().getLegacyOperatorState().get(chainIndexAndJobVertex.f0));
				}
			}

			// Await state is restored
			StatefulCounter.getRestoreLatch().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

			// Await some progress after restore
			StatefulCounter.getProgressLatch().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

			// - Verification END ---------------------------------------------

			LOG.info("Cancelling job " + jobId + ".");
			jobManager.tell(new CancelJob(jobId));

			LOG.info("Disposing savepoint " + savepointPath + ".");
			Future<Object> disposeFuture = jobManager.ask(new DisposeSavepoint(savepointPath), deadline.timeLeft());

			errMsg = "Failed to dispose savepoint " + savepointPath + ".";
			Object resp = Await.result(disposeFuture, deadline.timeLeft());
			assertTrue(errMsg, resp.getClass() == getDisposeSavepointSuccess().getClass());

			// - Verification START -------------------------------------------

			// The checkpoint files
			List<File> checkpointFiles = new ArrayList<>();

			for (OperatorState stateForTaskGroup : savepoint.getOperatorStates()) {
				for (OperatorSubtaskState subtaskState : stateForTaskGroup.getStates()) {
					StreamStateHandle streamTaskState = subtaskState.getLegacyOperatorState();

					if (streamTaskState != null) {
						FileStateHandle fileStateHandle = (FileStateHandle) streamTaskState;
						checkpointFiles.add(new File(fileStateHandle.getFilePath().toUri()));
					}
				}
			}

			// The checkpoint files of the savepoint should have been discarded
			for (File f : checkpointFiles) {
				errMsg = "Checkpoint file " + f + " not cleaned up properly.";
				assertFalse(errMsg, f.exists());
			}

			if (checkpointFiles.size() > 0) {
				File parent = checkpointFiles.get(0).getParentFile();
				errMsg = "Checkpoint parent directory " + parent + " not cleaned up properly.";
				assertFalse(errMsg, parent.exists());
			}

			// All savepoints should have been cleaned up
			errMsg = "Savepoints directory not cleaned up properly: " +
				Arrays.toString(savepointRootDir.listFiles()) + ".";
			assertEquals(errMsg, 0, savepointRootDir.listFiles().length);

			// - Verification END ---------------------------------------------
		} finally {
			if (flink != null) {
				flink.shutdown();
			}
		}
	}
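	/**
	 * Submits a job with a non-existing savepoint path and verifies that the
	 * submission fails with a {@link FileNotFoundException} as the root cause
	 * instead of being retried by the restart strategy.
	 */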
	@Test
	public void testSubmitWithUnknownSavepointPath() throws Exception {
		// Config
		int numTaskManagers = 1;
		int numSlotsPerTaskManager = 1;
		int parallelism = numTaskManagers * numSlotsPerTaskManager;

		// Test deadline
		final Deadline deadline = new FiniteDuration(5, TimeUnit.MINUTES).fromNow();

		final File tmpDir = CommonTestUtils.createTempDirectory();
		final File savepointDir = new File(tmpDir, "savepoints");

		TestingCluster flink = null;

		try {
			// Flink configuration
			final Configuration config = new Configuration();
			config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTaskManagers);
			config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, numSlotsPerTaskManager);
			config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY, savepointDir.toURI().toString());

			LOG.info("Flink configuration: " + config + ".");

			// Start Flink
			flink = new TestingCluster(config);
			LOG.info("Starting Flink cluster.");
			flink.start();

			// Retrieve the job manager
			LOG.info("Retrieving JobManager.");
			ActorGateway jobManager = Await.result(
				flink.leaderGateway().future(),
				deadline.timeLeft());
			LOG.info("JobManager: " + jobManager + ".");

			// High value to ensure timeouts if restarted.
			int numberOfRetries = 1000;
			// Submit the job
			// Long delay to ensure that the test times out if the job
			// manager tries to restart the job.
			final JobGraph jobGraph = createJobGraph(parallelism, numberOfRetries, 3600000);

			// Set non-existing savepoint path
			jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath("unknown path"));
			assertEquals("unknown path", jobGraph.getSavepointRestoreSettings().getRestorePath());

			LOG.info("Submitting job " + jobGraph.getJobID() + " in detached mode.");

			try {
				flink.submitJobAndWait(jobGraph, false);
				fail("The job submission should have failed for the unknown savepoint path.");
			} catch (Exception e) {
				assertEquals(JobExecutionException.class, e.getClass());
				assertEquals(FileNotFoundException.class, e.getCause().getClass());
			}
		} finally {
			if (flink != null) {
				flink.shutdown();
			}
		}
	}

	/**
	 * FLINK-5985.
	 *
	 * <p>This test ensures that we can restore from a savepoint under modifications
	 * to the job graph that only concern stateless operators.
	 */
	@Test
	public void testCanRestoreWithModifiedStatelessOperators() throws Exception {
		// Config
		int numTaskManagers = 2;
		int numSlotsPerTaskManager = 2;
		int parallelism = 2;

		// Test deadline
		final Deadline deadline = new FiniteDuration(5, TimeUnit.MINUTES).fromNow();

		final File tmpDir = CommonTestUtils.createTempDirectory();
		final File savepointDir = new File(tmpDir, "savepoints");

		TestingCluster flink = null;
		String savepointPath;

		try {
			// Flink configuration
			final Configuration config = new Configuration();
			config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTaskManagers);
			config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, numSlotsPerTaskManager);
			config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY, savepointDir.toURI().toString());

			LOG.info("Flink configuration: " + config + ".");

			// Start Flink
			flink = new TestingCluster(config);
			LOG.info("Starting Flink cluster.");
			flink.start(true);

			// Retrieve the job manager
			LOG.info("Retrieving JobManager.");
			ActorGateway jobManager = Await.result(
				flink.leaderGateway().future(),
				deadline.timeLeft());
			LOG.info("JobManager: " + jobManager + ".");

			final StatefulCounter statefulCounter = new StatefulCounter();
			StatefulCounter.resetForTest(parallelism);

			StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
			env.setParallelism(parallelism);
			env.addSource(new InfiniteTestSource())
				.shuffle()
				.map(new MapFunction<Integer, Integer>() {

					@Override
					public Integer map(Integer value) throws Exception {
						return 4 * value;
					}
				})
				.shuffle()
				.map(statefulCounter).uid("statefulCounter")
				.shuffle()
				.map(new MapFunction<Integer, Integer>() {

					@Override
					public Integer map(Integer value) throws Exception {
						return 2 * value;
					}
				})
				.addSink(new DiscardingSink<Integer>());

			JobGraph originalJobGraph = env.getStreamGraph().getJobGraph();

			JobSubmissionResult submissionResult = flink.submitJobDetached(originalJobGraph);
			JobID jobID = submissionResult.getJobID();

			// wait for the Tasks to be ready
			StatefulCounter.getProgressLatch().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

			Future<Object> savepointPathFuture = jobManager.ask(new TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
			savepointPath = ((TriggerSavepointSuccess) Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();

			Future<Object> savepointFuture = jobManager.ask(new RequestSavepoint(savepointPath), deadline.timeLeft());
			((ResponseSavepoint) Await.result(savepointFuture, deadline.timeLeft())).savepoint();
			LOG.info("Retrieved savepoint: " + savepointPath + ".");

			// Shut down the Flink cluster (thereby canceling the job)
			LOG.info("Shutting down Flink cluster.");
			flink.shutdown();
			flink.awaitTermination();
		} finally {
			if (flink != null) {
				flink.shutdown();
				flink.awaitTermination();
			}
		}

		try {
			LOG.info("Restarting Flink cluster.");
			flink.start(true);

			// Retrieve the job manager
			LOG.info("Retrieving JobManager.");
			ActorGateway jobManager = Await.result(flink.leaderGateway().future(), deadline.timeLeft());
			LOG.info("JobManager: " + jobManager + ".");

			// Reset static test helpers
			StatefulCounter.resetForTest(parallelism);

			StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
			env.setParallelism(parallelism);

			// Generate a modified job graph: the stateless maps around the stateful
			// counter are removed and a different (identity) stateless map is added
			// after it. Only the stateful operator keeps its uid, so its state can
			// still be matched against the savepoint.
			env.addSource(new InfiniteTestSource())
				.shuffle()
				.map(new StatefulCounter()).uid("statefulCounter")
				.shuffle()
				.map(new MapFunction<Integer, Integer>() {

					@Override
					public Integer map(Integer value) throws Exception {
						return value;
					}
				})
				.addSink(new DiscardingSink<Integer>());

			JobGraph modifiedJobGraph = env.getStreamGraph().getJobGraph();

			// Set the savepoint path
			modifiedJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));

			LOG.info("Resubmitting job " + modifiedJobGraph.getJobID() + " with " +
				"savepoint path " + savepointPath + " in detached mode.");

			// Submit the job
			flink.submitJobDetached(modifiedJobGraph);

			// Await state is restored
			StatefulCounter.getRestoreLatch().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

			// Await some progress after restore
			StatefulCounter.getProgressLatch().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
		} finally {
			flink.shutdown();
			flink.awaitTermination();
		}
	}

	// ------------------------------------------------------------------------
	// Test program
	// ------------------------------------------------------------------------

	/**
	 * Creates a streaming JobGraph from the StreamEnvironment.
	 */
	private JobGraph createJobGraph(
		int parallelism,
		int numberOfRetries,
		long restartDelay) {

		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setParallelism(parallelism);
		env.disableOperatorChaining();
		env.getConfig().setRestartStrategy(RestartStrategies.fixedDelayRestart(numberOfRetries, restartDelay));
		env.getConfig().disableSysoutLogging();

		DataStream<Integer> stream = env
			.addSource(new InfiniteTestSource())
			.shuffle()
			.map(new StatefulCounter());

		stream.addSink(new DiscardingSink<Integer>());

		return env.getStreamGraph().getJobGraph();
	}

	private static class InfiniteTestSource implements SourceFunction<Integer> {

		private static final long serialVersionUID = 1L;
		private volatile boolean running = true;

		@Override
		public void run(SourceContext<Integer> ctx) throws Exception {
			while (running) {
				synchronized (ctx.getCheckpointLock()) {
					ctx.collect(1);
				}
				Thread.sleep(1);
			}
		}

		@Override
		public void cancel() {
			running = false;
		}
	}

	private static class StatefulCounter extends RichMapFunction<Integer, Integer> implements ListCheckpointed<byte[]> {

		private static volatile CountDownLatch progressLatch = new CountDownLatch(0);
		private static volatile CountDownLatch restoreLatch = new CountDownLatch(0);

		private int numCollectedElements = 0;

		private static final long serialVersionUID = 7317800376639115920L;

		private byte[] data;

		@Override
		public void open(Configuration parameters) throws Exception {
			if (data == null) {
				// We need this to be large, because we want to test with files
				Random rand = new Random(getRuntimeContext().getIndexOfThisSubtask());
				data = new byte[FsStateBackend.DEFAULT_FILE_STATE_THRESHOLD + 1];
				rand.nextBytes(data);
			}
		}

		@Override
		public Integer map(Integer value) throws Exception {
			for (int i = 0; i < data.length; i++) {
				data[i] += 1;
			}
			if (numCollectedElements++ > 10) {
				progressLatch.countDown();
			}
			return value;
		}

		@Override
		public List<byte[]> snapshotState(long checkpointId, long timestamp) throws Exception {
			return Collections.singletonList(data);
		}

		@Override
		public void restoreState(List<byte[]> state) throws Exception {
			if (state.isEmpty() || state.size() > 1) {
				throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
			}
			this.data = state.get(0);

			restoreLatch.countDown();
		}

		// --------------------------------------------------------------------

		static CountDownLatch getProgressLatch() {
			return progressLatch;
		}

		static CountDownLatch getRestoreLatch() {
			return restoreLatch;
		}
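		// Both latches are sized to the job parallelism. Each subtask counts the
		// restore latch down exactly once, so awaiting it means every subtask has
		// restored its state. The progress latch is counted down repeatedly once a
		// subtask has processed more than ten elements, signalling that the
		// pipeline is making progress.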
		static void resetForTest(int parallelism) {
			progressLatch = new CountDownLatch(parallelism);
			restoreLatch = new CountDownLatch(parallelism);
		}
	}

	private static final int ITER_TEST_PARALLELISM = 1;
	private static OneShotLatch[] ITER_TEST_SNAPSHOT_WAIT = new OneShotLatch[ITER_TEST_PARALLELISM];
	private static OneShotLatch[] ITER_TEST_RESTORE_WAIT = new OneShotLatch[ITER_TEST_PARALLELISM];
	private static int[] ITER_TEST_CHECKPOINT_VERIFY = new int[ITER_TEST_PARALLELISM];

	@Test
	public void testSavepointForJobWithIteration() throws Exception {

		for (int i = 0; i < ITER_TEST_PARALLELISM; ++i) {
			ITER_TEST_SNAPSHOT_WAIT[i] = new OneShotLatch();
			ITER_TEST_RESTORE_WAIT[i] = new OneShotLatch();
			ITER_TEST_CHECKPOINT_VERIFY[i] = 0;
		}

		TemporaryFolder folder = new TemporaryFolder();
		folder.create();
		// Temporary directory for file state backend
		final File tmpDir = folder.newFolder();

		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		final IntegerStreamSource source = new IntegerStreamSource();
		IterativeStream<Integer> iteration = env.addSource(source)
			.flatMap(new RichFlatMapFunction<Integer, Integer>() {

				private static final long serialVersionUID = 1L;

				@Override
				public void flatMap(Integer in, Collector<Integer> clctr) throws Exception {
					clctr.collect(in);
				}
			}).setParallelism(ITER_TEST_PARALLELISM)
			.keyBy(new KeySelector<Integer, Object>() {

				private static final long serialVersionUID = 1L;

				@Override
				public Object getKey(Integer value) throws Exception {
					return value;
				}
			})
			.flatMap(new DuplicateFilter())
			.setParallelism(ITER_TEST_PARALLELISM)
			.iterate();

		DataStream<Integer> iterationBody = iteration
			.map(new MapFunction<Integer, Integer>() {

				private static final long serialVersionUID = 1L;

				@Override
				public Integer map(Integer value) throws Exception {
					return value;
				}
			})
			.setParallelism(ITER_TEST_PARALLELISM);

		iteration.closeWith(iterationBody);

		StreamGraph streamGraph = env.getStreamGraph();
		streamGraph.setJobName("Test");

		JobGraph jobGraph = streamGraph.getJobGraph();

		Configuration config = new Configuration();
		config.addAll(jobGraph.getJobConfiguration());
		config.setLong(TaskManagerOptions.MANAGED_MEMORY_SIZE, -1L);
		config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 2 * jobGraph.getMaximumParallelism());

		final File checkpointDir = new File(tmpDir, "checkpoints");
		final File savepointDir = new File(tmpDir, "savepoints");

		if (!checkpointDir.mkdir() || !savepointDir.mkdirs()) {
			fail("Test setup failed: failed to create temporary directories.");
		}

		config.setString(CoreOptions.STATE_BACKEND, "filesystem");
		config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY, checkpointDir.toURI().toString());
		config.setString(FsStateBackendFactory.MEMORY_THRESHOLD_CONF_KEY, "0");
		config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY, savepointDir.toURI().toString());

		TestingCluster cluster = new TestingCluster(config, false);
		String savepointPath = null;
		try {
			cluster.start();

			cluster.submitJobDetached(jobGraph);
			for (OneShotLatch latch : ITER_TEST_SNAPSHOT_WAIT) {
				latch.await();
			}
			savepointPath = cluster.triggerSavepoint(jobGraph.getJobID());
			source.cancel();

			jobGraph = streamGraph.getJobGraph();
			jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));

			cluster.submitJobDetached(jobGraph);
			for (OneShotLatch latch : ITER_TEST_RESTORE_WAIT) {
				latch.await();
			}
			source.cancel();
		} finally {
			if (null != savepointPath) {
				cluster.disposeSavepoint(savepointPath);
			}
			cluster.stop();
			cluster.awaitTermination();
		}
	}
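	// The source records the value captured by the last snapshot in
	// ITER_TEST_CHECKPOINT_VERIFY, so that restoreState can assert that exactly
	// the checkpointed value comes back after resuming from the savepoint.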
	private static final class IntegerStreamSource
		extends RichSourceFunction<Integer>
		implements ListCheckpointed<Integer> {

		private static final long serialVersionUID = 1L;
		private volatile boolean running;
		private volatile boolean isRestored;
		private int emittedCount;

		public IntegerStreamSource() {
			this.running = true;
			this.isRestored = false;
			this.emittedCount = 0;
		}

		@Override
		public void run(SourceContext<Integer> ctx) throws Exception {

			while (running) {
				synchronized (ctx.getCheckpointLock()) {
					ctx.collect(emittedCount);
				}

				if (emittedCount < 100) {
					++emittedCount;
				} else {
					emittedCount = 0;
				}
				Thread.sleep(1);
			}
		}

		@Override
		public void cancel() {
			running = false;
		}

		@Override
		public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
			ITER_TEST_CHECKPOINT_VERIFY[getRuntimeContext().getIndexOfThisSubtask()] = emittedCount;
			return Collections.singletonList(emittedCount);
		}

		@Override
		public void restoreState(List<Integer> state) throws Exception {
			if (!state.isEmpty()) {
				this.emittedCount = state.get(0);
			}
			Assert.assertEquals(ITER_TEST_CHECKPOINT_VERIFY[getRuntimeContext().getIndexOfThisSubtask()], emittedCount);
			ITER_TEST_RESTORE_WAIT[getRuntimeContext().getIndexOfThisSubtask()].trigger();
		}
	}

	public static class DuplicateFilter extends RichFlatMapFunction<Integer, Integer> {

		static final ValueStateDescriptor<Boolean> descriptor = new ValueStateDescriptor<>("seen", Boolean.class, false);

		private static final long serialVersionUID = 1L;

		private ValueState<Boolean> operatorState;

		@Override
		public void open(Configuration configuration) {
			operatorState = this.getRuntimeContext().getState(descriptor);
		}

		@Override
		public void flatMap(Integer value, Collector<Integer> out) throws Exception {
			if (!operatorState.value()) {
				out.collect(value);
				operatorState.update(true);
			}

			if (30 == value) {
				ITER_TEST_SNAPSHOT_WAIT[getRuntimeContext().getIndexOfThisSubtask()].trigger();
			}
		}
	}
}