/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.streaming;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.apache.beam.runners.core.StateMerging;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaceForTest;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.operators.KeyContext;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for
* {@code InMemoryStateInternals}.
*/
@RunWith(JUnit4.class)
public class FlinkKeyGroupStateInternalsTest {
private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
private static final StateTag<BagState<String>> STRING_BAG_ADDR =
StateTags.bag("stringBag", StringUtf8Coder.of());
FlinkKeyGroupStateInternals<String> underTest;
private KeyedStateBackend keyedStateBackend;
@Before
public void initStateInternals() {
try {
keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups,
KeyGroupRange keyGroupRange) {
MemoryStateBackend backend = new MemoryStateBackend();
try {
AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
new DummyEnvironment("test", 1, 0),
new JobID(),
"test_op",
new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
numberOfKeyGroups,
keyGroupRange,
new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
keyedStateBackend.setCurrentKey(ByteBuffer.wrap(
CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1")));
return keyedStateBackend;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Test
public void testBag() throws Exception {
BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
assertThat(value.read(), Matchers.emptyIterable());
value.add("hello");
assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
value.add("world");
assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
value.clear();
assertThat(value.read(), Matchers.emptyIterable());
assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
}
@Test
public void testBagIsEmpty() throws Exception {
BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
assertThat(value.isEmpty().read(), Matchers.is(true));
ReadableState<Boolean> readFuture = value.isEmpty();
value.add("hello");
assertThat(readFuture.read(), Matchers.is(false));
value.clear();
assertThat(readFuture.read(), Matchers.is(true));
}
@Test
public void testMergeBagIntoSource() throws Exception {
BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
bag1.add("Hello");
bag2.add("World");
bag1.add("!");
StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
// Reading the merged bag gets both the contents
assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
assertThat(bag2.read(), Matchers.emptyIterable());
}
@Test
public void testMergeBagIntoNewNamespace() throws Exception {
BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
bag1.add("Hello");
bag2.add("World");
bag1.add("!");
StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
// Reading the merged bag gets both the contents
assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
assertThat(bag1.read(), Matchers.emptyIterable());
assertThat(bag2.read(), Matchers.emptyIterable());
}
@Test
public void testKeyGroupAndCheckpoint() throws Exception {
// assign to keyGroup 0
ByteBuffer key0 = ByteBuffer.wrap(
CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111"));
// assign to keyGroup 1
ByteBuffer key1 = ByteBuffer.wrap(
CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222"));
FlinkKeyGroupStateInternals<String> allState;
{
KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
allState = new FlinkKeyGroupStateInternals<>(
StringUtf8Coder.of(), keyedStateBackend);
BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR);
BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR);
keyedStateBackend.setCurrentKey(key0);
valueForNamespace0.add("0");
valueForNamespace1.add("2");
keyedStateBackend.setCurrentKey(key1);
valueForNamespace0.add("1");
valueForNamespace1.add("3");
assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
}
ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader();
// 1. scale up
ByteArrayOutputStream out0 = new ByteArrayOutputStream();
allState.snapshotKeyGroupState(0, new DataOutputStream(out0));
DataInputStream in0 = new DataInputStream(
new ByteArrayInputStream(out0.toByteArray()));
{
KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0));
FlinkKeyGroupStateInternals<String> state0 =
new FlinkKeyGroupStateInternals<>(
StringUtf8Coder.of(), keyedStateBackend);
state0.restoreKeyGroupState(0, in0, classLoader);
BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR);
BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR);
assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0"));
assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2"));
}
ByteArrayOutputStream out1 = new ByteArrayOutputStream();
allState.snapshotKeyGroupState(1, new DataOutputStream(out1));
DataInputStream in1 = new DataInputStream(
new ByteArrayInputStream(out1.toByteArray()));
{
KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1));
FlinkKeyGroupStateInternals<String> state1 =
new FlinkKeyGroupStateInternals<>(
StringUtf8Coder.of(), keyedStateBackend);
state1.restoreKeyGroupState(1, in1, classLoader);
BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR);
BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR);
assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1"));
assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3"));
}
// 2. scale down
{
KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>(
StringUtf8Coder.of(), keyedStateBackend);
in0.reset();
in1.reset();
newAllState.restoreKeyGroupState(0, in0, classLoader);
newAllState.restoreKeyGroupState(1, in1, classLoader);
BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR);
BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR);
assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
}
}
private static class TestKeyContext implements KeyContext {
private Object key;
@Override
public void setCurrentKey(Object key) {
this.key = key;
}
@Override
public Object getCurrentKey() {
return key;
}
}
}