/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.state; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerSerializationProxy; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.state.DefaultOperatorStateBackend.PartitionableListState; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory; import org.apache.flink.util.FutureUtil; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import java.io.File; import java.io.IOException; import java.io.Serializable; import java.util.Collections; import java.util.Iterator; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.FutureTask; import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(OperatorBackendStateMetaInfoSnapshotReaderWriters.class) public class OperatorStateBackendTest { private final ClassLoader classLoader = getClass().getClassLoader(); @Test public void testCreateOnAbstractStateBackend() throws Exception { // we use the memory state backend as a subclass of the AbstractStateBackend final AbstractStateBackend abstractStateBackend = new MemoryStateBackend(); final OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend( createMockEnvironment(), "test-operator"); assertNotNull(operatorStateBackend); assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty()); } @Test public void testRegisterStatesWithoutTypeSerializer() throws Exception { // prepare an execution config with a non standard type registered final Class<?> registeredType = FutureTask.class; // validate the precondition of this test - if this condition fails, we need to pick a different // example serializer assertFalse(new KryoSerializer<>(File.class, new ExecutionConfig()).getKryo().getDefaultSerializer(registeredType) instanceof com.esotericsoftware.kryo.serializers.JavaSerializer); final ExecutionConfig cfg = new ExecutionConfig(); cfg.registerTypeWithKryoSerializer(registeredType, com.esotericsoftware.kryo.serializers.JavaSerializer.class); final OperatorStateBackend operatorStateBackend = new DefaultOperatorStateBackend(classLoader, cfg, false); ListStateDescriptor<File> stateDescriptor = new ListStateDescriptor<>("test", File.class); ListStateDescriptor<String> stateDescriptor2 = new ListStateDescriptor<>("test2", String.class); ListState<File> listState = operatorStateBackend.getListState(stateDescriptor); assertNotNull(listState); ListState<String> listState2 = operatorStateBackend.getListState(stateDescriptor2); assertNotNull(listState2); assertEquals(2, operatorStateBackend.getRegisteredStateNames().size()); // make sure that type registrations are forwarded TypeSerializer<?> serializer = ((PartitionableListState<?>) listState).getStateMetaInfo().getPartitionStateSerializer(); assertTrue(serializer instanceof KryoSerializer); assertTrue(((KryoSerializer<?>) serializer).getKryo().getSerializer(registeredType) instanceof com.esotericsoftware.kryo.serializers.JavaSerializer); Iterator<String> it = listState2.get().iterator(); assertFalse(it.hasNext()); listState2.add("kevin"); listState2.add("sunny"); it = listState2.get().iterator(); assertEquals("kevin", it.next()); assertEquals("sunny", it.next()); assertFalse(it.hasNext()); } @Test public void testRegisterStates() throws Exception { final OperatorStateBackend operatorStateBackend = new DefaultOperatorStateBackend(classLoader, new ExecutionConfig(), false); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListState<Serializable> listState1 = operatorStateBackend.getListState(stateDescriptor1); assertNotNull(listState1); assertEquals(1, operatorStateBackend.getRegisteredStateNames().size()); Iterator<Serializable> it = listState1.get().iterator(); assertFalse(it.hasNext()); listState1.add(42); listState1.add(4711); it = listState1.get().iterator(); assertEquals(42, it.next()); assertEquals(4711, it.next()); assertFalse(it.hasNext()); ListState<Serializable> listState2 = operatorStateBackend.getListState(stateDescriptor2); assertNotNull(listState2); assertEquals(2, operatorStateBackend.getRegisteredStateNames().size()); assertFalse(it.hasNext()); listState2.add(7); listState2.add(13); listState2.add(23); it = listState2.get().iterator(); assertEquals(7, it.next()); assertEquals(13, it.next()); assertEquals(23, it.next()); assertFalse(it.hasNext()); ListState<Serializable> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); assertNotNull(listState3); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); assertFalse(it.hasNext()); listState3.add(17); listState3.add(3); listState3.add(123); it = listState3.get().iterator(); assertEquals(17, it.next()); assertEquals(3, it.next()); assertEquals(123, it.next()); assertFalse(it.hasNext()); ListState<Serializable> listState1b = operatorStateBackend.getListState(stateDescriptor1); assertNotNull(listState1b); listState1b.add(123); it = listState1b.get().iterator(); assertEquals(42, it.next()); assertEquals(4711, it.next()); assertEquals(123, it.next()); assertFalse(it.hasNext()); it = listState1.get().iterator(); assertEquals(42, it.next()); assertEquals(4711, it.next()); assertEquals(123, it.next()); assertFalse(it.hasNext()); it = listState1b.get().iterator(); assertEquals(42, it.next()); assertEquals(4711, it.next()); assertEquals(123, it.next()); assertFalse(it.hasNext()); try { operatorStateBackend.getUnionListState(stateDescriptor2); fail("Did not detect changed mode"); } catch (IllegalStateException ignored) { } try { operatorStateBackend.getListState(stateDescriptor3); fail("Did not detect changed mode"); } catch (IllegalStateException ignored) { } } @Test public void testSnapshotEmpty() throws Exception { final AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); final OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "testOperator"); CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator"); RunnableFuture<OperatorStateHandle> snapshot = operatorStateBackend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint()); OperatorStateHandle stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); assertNull(stateHandle); } @Test public void testSnapshotRestoreSync() throws Exception { AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "test-op-name"); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListState<Serializable> listState1 = operatorStateBackend.getListState(stateDescriptor1); ListState<Serializable> listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState<Serializable> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); listState1.add(42); listState1.add(4711); listState2.add(7); listState2.add(13); listState2.add(23); listState3.add(17); listState3.add(18); listState3.add(19); listState3.add(20); CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator"); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forFullCheckpoint()); OperatorStateHandle stateHandle = FutureUtil.runIfNotDoneAndGet(runnableFuture); try { operatorStateBackend.close(); operatorStateBackend.dispose(); operatorStateBackend = abstractStateBackend.createOperatorStateBackend( createMockEnvironment(), "testOperator"); operatorStateBackend.restore(Collections.singletonList(stateHandle)); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); listState1 = operatorStateBackend.getListState(stateDescriptor1); listState2 = operatorStateBackend.getListState(stateDescriptor2); listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); Iterator<Serializable> it = listState1.get().iterator(); assertEquals(42, it.next()); assertEquals(4711, it.next()); assertFalse(it.hasNext()); it = listState2.get().iterator(); assertEquals(7, it.next()); assertEquals(13, it.next()); assertEquals(23, it.next()); assertFalse(it.hasNext()); it = listState3.get().iterator(); assertEquals(17, it.next()); assertEquals(18, it.next()); assertEquals(19, it.next()); assertEquals(20, it.next()); assertFalse(it.hasNext()); operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { stateHandle.discardState(); } } @Test public void testSnapshotRestoreAsync() throws Exception { OperatorStateBackend operatorStateBackend = new DefaultOperatorStateBackend(OperatorStateBackendTest.class.getClassLoader(), new ExecutionConfig(), true); ListStateDescriptor<MutableType> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<MutableType>()); ListStateDescriptor<MutableType> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<MutableType>()); ListStateDescriptor<MutableType> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<MutableType>()); ListState<MutableType> listState1 = operatorStateBackend.getListState(stateDescriptor1); ListState<MutableType> listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState<MutableType> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); listState2.add(MutableType.of(7)); listState2.add(MutableType.of(13)); listState2.add(MutableType.of(23)); listState3.add(MutableType.of(17)); listState3.add(MutableType.of(18)); listState3.add(MutableType.of(19)); listState3.add(MutableType.of(20)); BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); OneShotLatch blockerLatch = new OneShotLatch(); streamFactory.setWaiterLatch(waiterLatch); streamFactory.setBlockerLatch(blockerLatch); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forFullCheckpoint()); ExecutorService executorService = Executors.newFixedThreadPool(1); executorService.submit(runnableFuture); // wait until the async checkpoint is in the write code, then continue waiterLatch.await(); // do some mutations to the state, to test if our snapshot will NOT reflect them listState1.add(MutableType.of(77)); int n = 0; for (MutableType mutableType : listState2.get()) { if (++n == 2) { // allow the write code to continue, so that we could do changes while state is written in parallel. blockerLatch.trigger(); } mutableType.setValue(mutableType.getValue() + 10); } listState3.clear(); operatorStateBackend.getListState( new ListStateDescriptor<>("test4", new JavaSerializer<MutableType>())); // run the snapshot OperatorStateHandle stateHandle = runnableFuture.get(); try { operatorStateBackend.close(); operatorStateBackend.dispose(); AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); operatorStateBackend = abstractStateBackend.createOperatorStateBackend( createMockEnvironment(), "testOperator"); operatorStateBackend.restore(Collections.singletonList(stateHandle)); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); listState1 = operatorStateBackend.getListState(stateDescriptor1); listState2 = operatorStateBackend.getListState(stateDescriptor2); listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); Iterator<MutableType> it = listState1.get().iterator(); assertEquals(42, it.next().value); assertEquals(4711, it.next().value); assertFalse(it.hasNext()); it = listState2.get().iterator(); assertEquals(7, it.next().value); assertEquals(13, it.next().value); assertEquals(23, it.next().value); assertFalse(it.hasNext()); it = listState3.get().iterator(); assertEquals(17, it.next().value); assertEquals(18, it.next().value); assertEquals(19, it.next().value); assertEquals(20, it.next().value); assertFalse(it.hasNext()); operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { stateHandle.discardState(); } executorService.shutdown(); } @Test public void testSnapshotAsyncClose() throws Exception { DefaultOperatorStateBackend operatorStateBackend = new DefaultOperatorStateBackend(OperatorStateBackendTest.class.getClassLoader(), new ExecutionConfig(), true); ListStateDescriptor<MutableType> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<MutableType>()); ListState<MutableType> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); OneShotLatch blockerLatch = new OneShotLatch(); streamFactory.setWaiterLatch(waiterLatch); streamFactory.setBlockerLatch(blockerLatch); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forFullCheckpoint()); ExecutorService executorService = Executors.newFixedThreadPool(1); executorService.submit(runnableFuture); // wait until the async checkpoint is in the write code, then continue waiterLatch.await(); operatorStateBackend.close(); blockerLatch.trigger(); try { runnableFuture.get(60, TimeUnit.SECONDS); Assert.fail(); } catch (ExecutionException eex) { Assert.assertTrue(eex.getCause() instanceof IOException); } } @Test public void testSnapshotAsyncCancel() throws Exception { DefaultOperatorStateBackend operatorStateBackend = new DefaultOperatorStateBackend(OperatorStateBackendTest.class.getClassLoader(), new ExecutionConfig(), true); ListStateDescriptor<MutableType> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<MutableType>()); ListState<MutableType> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); OneShotLatch blockerLatch = new OneShotLatch(); streamFactory.setWaiterLatch(waiterLatch); streamFactory.setBlockerLatch(blockerLatch); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forFullCheckpoint()); ExecutorService executorService = Executors.newFixedThreadPool(1); executorService.submit(runnableFuture); // wait until the async checkpoint is in the stream's write code, then continue waiterLatch.await(); // cancel the future, which should close the underlying stream runnableFuture.cancel(true); Assert.assertTrue(streamFactory.getLastCreatedStream().isClosed()); // we allow the stream under test to proceed blockerLatch.trigger(); try { runnableFuture.get(60, TimeUnit.SECONDS); Assert.fail(); } catch (CancellationException ignore) { } } @Test public void testRestoreFailsIfSerializerDeserializationFails() throws Exception { AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "test-op-name"); // write some state ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListState<Serializable> listState1 = operatorStateBackend.getListState(stateDescriptor1); ListState<Serializable> listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState<Serializable> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); listState1.add(42); listState1.add(4711); listState2.add(7); listState2.add(13); listState2.add(23); listState3.add(17); listState3.add(18); listState3.add(19); listState3.add(20); CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator"); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forFullCheckpoint()); OperatorStateHandle stateHandle = FutureUtil.runIfNotDoneAndGet(runnableFuture); try { operatorStateBackend.close(); operatorStateBackend.dispose(); operatorStateBackend = abstractStateBackend.createOperatorStateBackend( createMockEnvironment(), "testOperator"); // mock failure when deserializing serializer TypeSerializerSerializationProxy<?> mockProxy = mock(TypeSerializerSerializationProxy.class); doThrow(new IOException()).when(mockProxy).read(any(DataInputViewStreamWrapper.class)); PowerMockito.whenNew(TypeSerializerSerializationProxy.class).withAnyArguments().thenReturn(mockProxy); operatorStateBackend.restore(Collections.singletonList(stateHandle)); fail("The operator state restore should have failed if the previous state serializer could not be loaded."); } catch (IOException expected) { Assert.assertTrue(expected.getMessage().contains("Unable to restore operator state")); } finally { stateHandle.discardState(); } } static final class MutableType implements Serializable { private static final long serialVersionUID = 1L; private int value; public MutableType() { this(0); } public MutableType(int value) { this.value = value; } public int getValue() { return value; } public void setValue(int value) { this.value = value; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } MutableType that = (MutableType) o; return value == that.value; } @Override public int hashCode() { return value; } static MutableType of(int value) { return new MutableType(value); } } // ------------------------------------------------------------------------ // utilities // ------------------------------------------------------------------------ private static Environment createMockEnvironment() { Environment env = mock(Environment.class); when(env.getExecutionConfig()).thenReturn(new ExecutionConfig()); when(env.getUserClassLoader()).thenReturn(OperatorStateBackendTest.class.getClassLoader()); return env; } }