/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.contrib.streaming.state; import org.apache.commons.io.FileUtils; import org.apache.commons.io.filefilter.IOFileFilter; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.IncrementalKeyedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateBackendTestBase; import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StreamStateHandle; import 
org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.rocksdb.ColumnFamilyDescriptor;
import org.rocksdb.ColumnFamilyHandle;
import org.rocksdb.ReadOptions;
import org.rocksdb.RocksDB;
import org.rocksdb.RocksIterator;
import org.rocksdb.RocksObject;
import org.rocksdb.Snapshot;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.RunnableFuture;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.mock;
import static org.powermock.api.mockito.PowerMockito.spy;

/**
 * Tests for the partitioned state part of {@link RocksDBStateBackend}.
 *
 * <p>The test is parameterized and runs once with incremental checkpointing disabled and once with
 * it enabled.
 */
@RunWith(Parameterized.class)
public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBackend> {

	private OneShotLatch blocker;
	private OneShotLatch waiter;
	private BlockerCheckpointStreamFactory testStreamFactory;
	private RocksDBKeyedStateBackend<Integer> keyedStateBackend;

	/** Spies of every iterator, snapshot and column family handle created through the spied db. */
	private List<RocksObject> allCreatedCloseables;

	private ValueState<Integer> testState1;
	private ValueState<String> testState2;

	@Parameterized.Parameters
	public static Collection<Boolean> parameters() {
		return Arrays.asList(false, true);
	}

	@Parameterized.Parameter
	public boolean enableIncrementalCheckpointing;

	@Rule
	public TemporaryFolder tempFolder = new TemporaryFolder();

	// Stored because the cleanup test inspects the db directory after dispose.
	String dbPath;

	@Override
	protected RocksDBStateBackend getStateBackend() throws IOException {
		dbPath = tempFolder.newFolder().getAbsolutePath();
		String checkpointPath = tempFolder.newFolder().toURI().toString();
		RocksDBStateBackend backend = new RocksDBStateBackend(new FsStateBackend(checkpointPath), enableIncrementalCheckpointing);
		backend.setDbStoragePath(dbPath);
		return backend;
	}

	/**
	 * Creates {@link #keyedStateBackend} with two value states and writes keys 0..99 into both.
	 *
	 * <p>The backend's db is replaced by a spy. Every iterator, snapshot and column family handle
	 * it creates is itself spied and recorded in {@link #allCreatedCloseables}, so tests can
	 * verify how often each one was closed. The checkpoint stream factory is wired to the
	 * {@link #blocker} and {@link #waiter} latches (after 10 invocations), which lets tests pause
	 * a running snapshot.
	 */
	public void setupRocksKeyedStateBackend() throws Exception {

		blocker = new OneShotLatch();
		waiter = new OneShotLatch();
		testStreamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);
		testStreamFactory.setBlockerLatch(blocker);
		testStreamFactory.setWaiterLatch(waiter);
		testStreamFactory.setAfterNumberInvocations(10);

		RocksDBStateBackend backend = getStateBackend();
		Environment env = new DummyEnvironment("TestTask", 1, 0);

		keyedStateBackend = (RocksDBKeyedStateBackend<Integer>) backend.createKeyedStateBackend(
			env,
			new JobID(),
			"Test",
			IntSerializer.INSTANCE,
			2,
			new KeyGroupRange(0, 1),
			mock(TaskKvStateRegistry.class));

		keyedStateBackend.restore(null);

		testState1 = keyedStateBackend.getPartitionedState(
			VoidNamespace.INSTANCE,
			VoidNamespaceSerializer.INSTANCE,
			new ValueStateDescriptor<>("TestState-1", Integer.class, 0));

		testState2 = keyedStateBackend.getPartitionedState(
			VoidNamespace.INSTANCE,
			VoidNamespaceSerializer.INSTANCE,
			new ValueStateDescriptor<>("TestState-2", String.class, ""));

		allCreatedCloseables = new ArrayList<>();

		keyedStateBackend.db = spy(keyedStateBackend.db);

		doAnswer(spyAndTrack(RocksIterator.class))
			.when(keyedStateBackend.db).newIterator(any(ColumnFamilyHandle.class), any(ReadOptions.class));

		doAnswer(spyAndTrack(Snapshot.class))
			.when(keyedStateBackend.db).getSnapshot();

		doAnswer(spyAndTrack(ColumnFamilyHandle.class))
			.when(keyedStateBackend.db).createColumnFamily(any(ColumnFamilyDescriptor.class));

		for (int i = 0; i < 100; ++i) {
			keyedStateBackend.setCurrentKey(i);
			testState1.update(4200 + i);
			testState2.update("S-" + (4200 + i));
		}
	}

	/**
	 * Returns an answer that calls the real method, wraps its result of the given type in a spy
	 * and records that spy in {@link #allCreatedCloseables} before returning it.
	 */
	private <T extends RocksObject> Answer<T> spyAndTrack(final Class<T> type) {
		return new Answer<T>() {
			@Override
			public T answer(InvocationOnMock invocationOnMock) throws Throwable {
				T created = spy(type.cast(invocationOnMock.callRealMethod()));
				allCreatedCloseables.add(created);
				return created;
			}
		};
	}

	@Test
	public void testRunningSnapshotAfterBackendClosed() throws Exception {
		setupRocksKeyedStateBackend();
		RunnableFuture<KeyedStateHandle> snapshot = snapshotThenDisposeBackend();

		runAsyncAndExpectFailure(snapshot);

		// Ensure every RocksObject was closed exactly once
		verifyAllCreatedCloseablesClosed(1);
	}

	@Test
	public void testReleasingSnapshotAfterBackendClosed() throws Exception {
		setupRocksKeyedStateBackend();
		RunnableFuture<KeyedStateHandle> snapshot = snapshotThenDisposeBackend();

		snapshot.cancel(true);

		// Ensure every RocksObject was closed exactly once
		verifyAllCreatedCloseablesClosed(1);
	}

	@Test
	public void testDismissingSnapshot() throws Exception {
		setupRocksKeyedStateBackend();
		RunnableFuture<KeyedStateHandle> snapshot = triggerFullSnapshot();
		snapshot.cancel(true);
		verifyRocksObjectsReleased();
	}

	@Test
	public void testDismissingSnapshotNotRunnable() throws Exception {
		setupRocksKeyedStateBackend();
		RunnableFuture<KeyedStateHandle> snapshot = triggerFullSnapshot();
		snapshot.cancel(true);
		runAsyncAndExpectFailure(snapshot);
		verifyRocksObjectsReleased();
	}

	@Test
	public void testCompletingSnapshot() throws Exception {
		setupRocksKeyedStateBackend();
		RunnableFuture<KeyedStateHandle> snapshot = triggerFullSnapshot();
		Thread asyncSnapshotThread = new Thread(snapshot);
		asyncSnapshotThread.start();
		waiter.await(); // wait for snapshot to run
		waiter.reset();
		runStateUpdates();
		blocker.trigger(); // allow checkpointing to start writing
		waiter.await(); // wait for snapshot stream writing to run
		KeyedStateHandle keyedStateHandle = snapshot.get();
		assertNotNull(keyedStateHandle);
		assertTrue(keyedStateHandle.getStateSize() > 0);
		assertEquals(2, keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
		assertTrue(testStreamFactory.getLastCreatedStream().isClosed());
		asyncSnapshotThread.join();
		verifyRocksObjectsReleased();
	}

	@Test
	public void testCancelRunningSnapshot() throws Exception {
		setupRocksKeyedStateBackend();
		RunnableFuture<KeyedStateHandle> snapshot = triggerFullSnapshot();
		Thread asyncSnapshotThread = new Thread(snapshot);
		asyncSnapshotThread.start();
		waiter.await(); // wait for snapshot to run
		waiter.reset();
		runStateUpdates();
		snapshot.cancel(true);
		blocker.trigger(); // allow checkpointing to start writing
		assertTrue(testStreamFactory.getLastCreatedStream().isClosed());
		waiter.await(); // wait for snapshot stream writing to run
		try {
			snapshot.get();
			fail();
		} catch (Exception ignored) {
			// expected: the snapshot was cancelled, so it must not yield a result
		}

		// get() on a cancelled future need not wait for the worker thread. Join the thread before
		// verifying, so that it has finished releasing RocksDB objects and no longer appends to
		// allCreatedCloseables while we iterate it.
		asyncSnapshotThread.join();
		verifyRocksObjectsReleased();
	}

	@Test
	public void testDisposeDeletesAllDirectories() throws Exception {
		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
		ValueStateDescriptor<String> kvId =
			new ValueStateDescriptor<>("id", String.class, null);

		kvId.initializeSerializerUnlessSet(new ExecutionConfig());

		ValueState<String> state =
			backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);

		backend.setCurrentKey(1);
		state.update("Hello");

		// more than just the root directory
		assertTrue(listAllFilesAndDirsInDbPath().size() > 1);

		backend.dispose();

		// just the root directory left
		assertEquals(1, listAllFilesAndDirsInDbPath().size());
	}

	@Test
	public void testSharedIncrementalStateDeRegistration() throws Exception {
		// Shared state handles are only produced by incremental checkpoints.
		assumeTrue(enableIncrementalCheckpointing);

		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
		try {
			ValueStateDescriptor<String> kvId =
				new ValueStateDescriptor<>("id", String.class, null);

			kvId.initializeSerializerUnlessSet(new ExecutionConfig());

			ValueState<String> state =
				backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);

			Queue<IncrementalKeyedStateHandle> previousStateHandles = new LinkedList<>();
			SharedStateRegistry sharedStateRegistry = spy(new SharedStateRegistry());
			for (int checkpointId = 0; checkpointId < 3; ++checkpointId) {

				reset(sharedStateRegistry);

				backend.setCurrentKey(checkpointId);
				state.update("Hello-" + checkpointId);

				RunnableFuture<KeyedStateHandle> snapshot = backend.snapshot(
					checkpointId,
					checkpointId,
					createStreamFactory(),
					CheckpointOptions.forFullCheckpoint());

				snapshot.run();

				IncrementalKeyedStateHandle stateHandle = (IncrementalKeyedStateHandle) snapshot.get();
				Map<StateHandleID, StreamStateHandle> sharedState =
					new HashMap<>(stateHandle.getSharedState());

				stateHandle.registerSharedStates(sharedStateRegistry);

				for (Map.Entry<StateHandleID, StreamStateHandle> e : sharedState.entrySet()) {
					verify(sharedStateRegistry).registerReference(
						stateHandle.createSharedStateRegistryKeyFromFileName(e.getKey()),
						e.getValue());
				}

				previousStateHandles.add(stateHandle);
				backend.notifyCheckpointComplete(checkpointId);

				//-----------------------------------------------------------------

				if (previousStateHandles.size() > 1) {
					checkRemove(previousStateHandles.remove(), sharedStateRegistry);
				}
			}

			while (!previousStateHandles.isEmpty()) {

				reset(sharedStateRegistry);

				checkRemove(previousStateHandles.remove(), sharedStateRegistry);
			}
		} finally {
			// release the backend even when an assertion above fails
			backend.close();
			backend.dispose();
		}
	}

	/**
	 * Discards {@code remove} and verifies that each of its shared state entries is unregistered
	 * from {@code registry} only as a consequence of that discard.
	 */
	private void checkRemove(IncrementalKeyedStateHandle remove, SharedStateRegistry registry) throws Exception {
		for (StateHandleID id : remove.getSharedState().keySet()) {
			verify(registry, times(0)).unregisterReference(
				remove.createSharedStateRegistryKeyFromFileName(id));
		}

		remove.discardState();

		for (StateHandleID id : remove.getSharedState().keySet()) {
			verify(registry).unregisterReference(
				remove.createSharedStateRegistryKeyFromFileName(id));
		}
	}

	/** Triggers checkpoint 0 as a full checkpoint on {@link #keyedStateBackend}. */
	private RunnableFuture<KeyedStateHandle> triggerFullSnapshot() throws Exception {
		return keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
	}

	/**
	 * Triggers a full snapshot and then disposes the backend before the snapshot is run.
	 *
	 * <p>Verifies that the db was closed and set to null, and that none of the tracked RocksDB
	 * objects has been closed yet. In the non-incremental case it also verifies that exactly one
	 * db snapshot was taken and none was released.
	 *
	 * @return the not-yet-run snapshot future
	 */
	private RunnableFuture<KeyedStateHandle> snapshotThenDisposeBackend() throws Exception {
		RunnableFuture<KeyedStateHandle> snapshot = triggerFullSnapshot();

		RocksDB spyDB = keyedStateBackend.db;

		if (!enableIncrementalCheckpointing) {
			verify(spyDB, times(1)).getSnapshot();
			verify(spyDB, times(0)).releaseSnapshot(any(Snapshot.class));
		}

		this.keyedStateBackend.dispose();
		verify(spyDB, times(1)).close();
		assertNull(keyedStateBackend.db);

		// Ensure no RocksObject has been closed yet
		verifyAllCreatedCloseablesClosed(0);
		return snapshot;
	}

	/**
	 * Runs {@code snapshot} on a new thread and asserts that retrieving its result fails.
	 * Joins the thread before returning, so later verifications see all of its side effects.
	 */
	private static void runAsyncAndExpectFailure(RunnableFuture<KeyedStateHandle> snapshot) throws InterruptedException {
		Thread asyncSnapshotThread = new Thread(snapshot);
		asyncSnapshotThread.start();
		try {
			snapshot.get();
			fail();
		} catch (Exception ignored) {
			// expected: the snapshot must not complete successfully; the exception type is not checked
		}
		asyncSnapshotThread.join();
	}

	/** Updates keys 50..149 of both test states, sleeping briefly every 10 keys. */
	private void runStateUpdates() throws Exception {
		for (int i = 50; i < 150; ++i) {
			if (i % 10 == 0) {
				Thread.sleep(1);
			}
			keyedStateBackend.setCurrentKey(i);
			testState1.update(4200 + i);
			testState2.update("S-" + (4200 + i));
		}
	}

	/** Verifies that {@code close()} was called exactly {@code expectedCloseCalls} times on every tracked RocksDB object. */
	private void verifyAllCreatedCloseablesClosed(int expectedCloseCalls) {
		for (RocksObject rocksCloseable : allCreatedCloseables) {
			verify(rocksCloseable, times(expectedCloseCalls)).close();
		}
	}

	/**
	 * Verifies that every tracked RocksDB object was closed exactly once. In the non-incremental
	 * case it also verifies that exactly one db snapshot was taken and released. It then disposes
	 * the still-open backend and checks that the db was closed once and set to null.
	 */
	private void verifyRocksObjectsReleased() {
		// Ensure every RocksObject was closed exactly once
		verifyAllCreatedCloseablesClosed(1);

		assertNotNull(keyedStateBackend.db);
		RocksDB spyDB = keyedStateBackend.db;

		if (!enableIncrementalCheckpointing) {
			verify(spyDB, times(1)).getSnapshot();
			verify(spyDB, times(1)).releaseSnapshot(any(Snapshot.class));
		}

		keyedStateBackend.dispose();
		verify(spyDB, times(1)).close();
		assertNull(keyedStateBackend.db);
	}

	/** Lists the db storage directory recursively, including the directory itself. */
	private Collection<File> listAllFilesAndDirsInDbPath() {
		return FileUtils.listFilesAndDirs(new File(dbPath), new AcceptAllFilter(), new AcceptAllFilter());
	}

	/** File filter that accepts every file and directory. */
	private static class AcceptAllFilter implements IOFileFilter {
		@Override
		public boolean accept(File file) {
			return true;
		}

		@Override
		public boolean accept(File file, String s) {
			return true;
		}
	}
}