/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.contrib.streaming.state; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import 
org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.AsynchronousException; import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness; import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.util.FutureUtil; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.LocalFileSystem; import org.junit.Assert; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import java.io.File; import java.io.IOException; import java.lang.reflect.Field; import java.net.URI; import java.util.Arrays; import java.util.UUID; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; 
import java.util.concurrent.ExecutorService; import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; /** * Tests for asynchronous RocksDB Key/Value state checkpoints. */ @RunWith(PowerMockRunner.class) @PrepareForTest({FileSystem.class}) @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*", "org.apache.log4j.*"}) @SuppressWarnings("serial") public class RocksDBAsyncSnapshotTest { /** * This ensures that asynchronous state handles are actually materialized asynchronously. * * <p>We use latches to block at various stages and see if the code still continues through * the parts that are not asynchronous. If the checkpoint is not done asynchronously the * test will simply lock forever. 
*/ @Test public void testFullyAsyncSnapshot() throws Exception { LocalFileSystem localFS = new LocalFileSystem(); localFS.initialize(new URI("file:///"), new Configuration()); PowerMockito.stub(PowerMockito.method(FileSystem.class, "get", URI.class, Configuration.class)).toReturn(localFS); final OneInputStreamTask<String, String> task = new OneInputStreamTask<>(); final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<>(task, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); testHarness.setupOutputForSingletonOperatorChain(); testHarness.configureForKeyedStream(new KeySelector<String, String>() { @Override public String getKey(String value) throws Exception { return value; } }, BasicTypeInfo.STRING_TYPE_INFO); StreamConfig streamConfig = testHarness.getStreamConfig(); File dbDir = new File(new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()), "state"); RocksDBStateBackend backend = new RocksDBStateBackend(new MemoryStateBackend()); backend.setDbStoragePath(dbDir.getAbsolutePath()); streamConfig.setStateBackend(backend); streamConfig.setStreamOperator(new AsyncCheckpointOperator()); final OneShotLatch delayCheckpointLatch = new OneShotLatch(); final OneShotLatch ensureCheckpointLatch = new OneShotLatch(); StreamMockEnvironment mockEnv = new StreamMockEnvironment( testHarness.jobConfig, testHarness.taskConfig, testHarness.memorySize, new MockInputSplitProvider(), testHarness.bufferSize) { @Override public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState checkpointStateHandles) { super.acknowledgeCheckpoint(checkpointId, checkpointMetrics); // block on the latch, to verify that triggerCheckpoint returns below, // even though the async checkpoint would not finish try { delayCheckpointLatch.await(); } catch (InterruptedException e) { throw new RuntimeException(e); } // should be one k/v state 
assertNotNull(checkpointStateHandles.getManagedKeyedState()); // we now know that the checkpoint went through ensureCheckpointLatch.trigger(); } }; testHarness.invoke(mockEnv); // wait for the task to be running for (Field field: StreamTask.class.getDeclaredFields()) { if (field.getName().equals("isRunning")) { field.setAccessible(true); while (!field.getBoolean(task)) { Thread.sleep(10); } } } task.triggerCheckpoint(new CheckpointMetaData(42, 17), CheckpointOptions.forFullCheckpoint()); testHarness.processElement(new StreamRecord<>("Wohoo", 0)); // now we allow the checkpoint delayCheckpointLatch.trigger(); // wait for the checkpoint to go through ensureCheckpointLatch.await(); testHarness.endInput(); ExecutorService threadPool = task.getAsyncOperationsThreadPool(); threadPool.shutdown(); Assert.assertTrue(threadPool.awaitTermination(60_000, TimeUnit.MILLISECONDS)); testHarness.waitForTaskCompletion(); if (mockEnv.wasFailedExternally()) { fail("Unexpected exception during execution."); } } /** * This tests ensures that canceling of asynchronous snapshots works as expected and does not block. 
* @throws Exception */ @Test @Ignore public void testCancelFullyAsyncCheckpoints() throws Exception { LocalFileSystem localFS = new LocalFileSystem(); localFS.initialize(new URI("file:///"), new Configuration()); PowerMockito.stub(PowerMockito.method(FileSystem.class, "get", URI.class, Configuration.class)).toReturn(localFS); final OneInputStreamTask<String, String> task = new OneInputStreamTask<>(); final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<>(task, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); testHarness.setupOutputForSingletonOperatorChain(); testHarness.configureForKeyedStream(new KeySelector<String, String>() { @Override public String getKey(String value) throws Exception { return value; } }, BasicTypeInfo.STRING_TYPE_INFO); StreamConfig streamConfig = testHarness.getStreamConfig(); File dbDir = new File(new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()), "state"); BlockingStreamMemoryStateBackend memoryStateBackend = new BlockingStreamMemoryStateBackend(); RocksDBStateBackend backend = new RocksDBStateBackend(memoryStateBackend); backend.setDbStoragePath(dbDir.getAbsolutePath()); streamConfig.setStateBackend(backend); streamConfig.setStreamOperator(new AsyncCheckpointOperator()); StreamMockEnvironment mockEnv = new StreamMockEnvironment( testHarness.jobConfig, testHarness.taskConfig, testHarness.memorySize, new MockInputSplitProvider(), testHarness.bufferSize); BlockingStreamMemoryStateBackend.waitFirstWriteLatch = new OneShotLatch(); BlockingStreamMemoryStateBackend.unblockCancelLatch = new OneShotLatch(); testHarness.invoke(mockEnv); // wait for the task to be running for (Field field: StreamTask.class.getDeclaredFields()) { if (field.getName().equals("isRunning")) { field.setAccessible(true); while (!field.getBoolean(task)) { Thread.sleep(10); } } } task.triggerCheckpoint(new CheckpointMetaData(42, 17), CheckpointOptions.forFullCheckpoint()); 
testHarness.processElement(new StreamRecord<>("Wohoo", 0)); BlockingStreamMemoryStateBackend.waitFirstWriteLatch.await(); task.cancel(); BlockingStreamMemoryStateBackend.unblockCancelLatch.trigger(); testHarness.endInput(); try { ExecutorService threadPool = task.getAsyncOperationsThreadPool(); threadPool.shutdown(); Assert.assertTrue(threadPool.awaitTermination(60_000, TimeUnit.MILLISECONDS)); testHarness.waitForTaskCompletion(); if (mockEnv.wasFailedExternally()) { throw new AsynchronousException(new InterruptedException("Exception was thrown as expected.")); } fail("Operation completed. Cancel failed."); } catch (Exception expected) { AsynchronousException asynchronousException = null; if (expected instanceof AsynchronousException) { asynchronousException = (AsynchronousException) expected; } else if (expected.getCause() instanceof AsynchronousException) { asynchronousException = (AsynchronousException) expected.getCause(); } else { fail("Unexpected exception: " + expected); } // we expect the exception from canceling snapshots Throwable innerCause = asynchronousException.getCause(); Assert.assertTrue("Unexpected inner cause: " + innerCause, innerCause instanceof CancellationException //future canceled || innerCause instanceof InterruptedException); //thread interrupted } } /** * Test that the snapshot files are cleaned up in case of a failure during the snapshot * procedure. 
*/
	@Test
	public void testCleanupOfSnapshotsInFailureCase() throws Exception {
		long checkpointId = 1L;
		long timestamp = 42L;

		Environment environment = new DummyEnvironment("test task", 1, 0);

		// a stream that fails on the first write, behind a mocked stream factory
		CheckpointStreamFactory.CheckpointStateOutputStream failingStream =
			mock(CheckpointStreamFactory.CheckpointStateOutputStream.class);
		CheckpointStreamFactory streamFactory = mock(CheckpointStreamFactory.class);
		AbstractStateBackend delegateBackend = mock(AbstractStateBackend.class);

		final IOException testException = new IOException("Test exception");

		doReturn(streamFactory).when(delegateBackend).createStreamFactory(any(JobID.class), anyString());
		doThrow(testException).when(failingStream).write(anyInt());
		doReturn(failingStream).when(streamFactory).createCheckpointStateOutputStream(eq(checkpointId), eq(timestamp));

		RocksDBStateBackend backend = new RocksDBStateBackend(delegateBackend);
		backend.setDbStoragePath("file:///tmp/foobar");

		AbstractKeyedStateBackend<Void> keyedBackend = backend.createKeyedStateBackend(
			environment,
			new JobID(),
			"test operator",
			VoidSerializer.INSTANCE,
			1,
			new KeyGroupRange(0, 0),
			null);

		keyedBackend.restore(null);

		// register a state so that the state backend has to checkpoint something
		keyedBackend.getPartitionedState(
			"namespace",
			StringSerializer.INSTANCE,
			new ValueStateDescriptor<>("foobar", String.class));

		RunnableFuture<KeyedStateHandle> snapshotFuture = keyedBackend.snapshot(
			checkpointId, timestamp, streamFactory, CheckpointOptions.forFullCheckpoint());

		try {
			FutureUtil.runIfNotDoneAndGet(snapshotFuture);
			fail("Expected an exception to be thrown here.");
		} catch (ExecutionException e) {
			Assert.assertEquals(testException, e.getCause());
		}

		// the stream must be closed even though the snapshot failed, so no files leak
		verify(failingStream).close();
	}

	@Test
	public void testConsistentSnapshotSerializationFlagsAndMasks() {

		// fixed wire-format constants of the full snapshot serialization
		Assert.assertEquals(0xFFFF, RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK);
		Assert.assertEquals(0x80, RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);

		byte[] expectedKey = new byte[] {42, 42};
		byte[] modKey = expectedKey.clone();

		Assert.assertFalse(RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));

		// setting and clearing the flag must round-trip back to the original key bytes
		RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.setMetaDataFollowsFlagInKey(modKey);
		Assert.assertTrue(RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));

		RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.clearMetaDataFollowsFlag(modKey);
		Assert.assertFalse(RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));

		Assert.assertTrue(Arrays.equals(expectedKey, modKey));
	}

	// ------------------------------------------------------------------------

	/**
	 * Memory state backend whose checkpoint output streams block every write on a
	 * latch, so tests can delay snapshot writing and cancel snapshots mid-write.
	 */
	static class BlockingStreamMemoryStateBackend extends MemoryStateBackend {

		// test-controlled latches; reset by the test before each use
		public static volatile OneShotLatch waitFirstWriteLatch = null;
		public static volatile OneShotLatch unblockCancelLatch = null;

		private volatile boolean closed = false;

		// Signals the first write attempt, then parks until the test allows progress.
		// Fails with IOException if the stream was closed (snapshot canceled) meanwhile.
		private void blockWrite() throws IOException {
			waitFirstWriteLatch.trigger();
			try {
				unblockCancelLatch.await();
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
			}
			if (closed) {
				throw new IOException("Stream closed.");
			}
		}

		@Override
		public CheckpointStreamFactory createStreamFactory(JobID jobId, String operatorIdentifier) throws IOException {
			return new MemCheckpointStreamFactory(4 * 1024 * 1024) {
				@Override
				public CheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) throws Exception {
					return new MemoryCheckpointOutputStream(4 * 1024 * 1024) {
						@Override
						public void write(int b) throws IOException {
							blockWrite();
							super.write(b);
						}

						@Override
						public void write(byte[] b, int off, int len) throws IOException {
							blockWrite();
							super.write(b, off, len);
						}

						@Override
						public void close() {
							closed = true;
							super.close();
						}
					};
				}
			};
		}
	}

	/**
	 * Identity-keyed test operator that touches keyed state in {@code open()} and on
	 * every element, so an async checkpoint always has RocksDB state to snapshot.
	 */
	public static class AsyncCheckpointOperator
		extends AbstractStreamOperator<String>
		implements OneInputStreamOperator<String, String>,
		StreamCheckpointedOperator {

		@Override
		public void open() throws Exception {
			super.open();

			// also get the state in open, this way we are sure that it was created before
			// we trigger the test checkpoint
			ValueState<String> state = getPartitionedState(
					VoidNamespace.INSTANCE,
					VoidNamespaceSerializer.INSTANCE,
					new ValueStateDescriptor<>("count", StringSerializer.INSTANCE));
		}

		@Override
		public void processElement(StreamRecord<String> element) throws Exception {
			// the stored value is irrelevant; the write only forces state to exist
			ValueState<String> state = getPartitionedState(
					VoidNamespace.INSTANCE,
					VoidNamespaceSerializer.INSTANCE,
					new ValueStateDescriptor<>("count", StringSerializer.INSTANCE));

			state.update(element.getValue());
		}

		@Override
		public void snapshotState(
				FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
			// do nothing so that we don't block
		}

		@Override
		public void restoreState(FSDataInputStream in) throws Exception {
			// do nothing so that we don't block
		}
	}

	/** Identity map function, kept as a reusable test fixture. */
	public static class DummyMapFunction<T> implements MapFunction<T, T> {
		@Override
		public T map(T value) {
			return value;
		}
	}
}