/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.operators;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.junit.Assert;
import org.junit.Test;
import java.io.InputStream;
import java.util.BitSet;
/**
 * Tests that managed keyed state, managed operator state, raw keyed state and raw operator state
 * written by an {@link AbstractStreamOperator} end up in a snapshot and are handed back to the
 * operator when a new instance is restored from that snapshot.
 */
public class StreamOperatorSnapshotRestoreTest {

	private static final int MAX_PARALLELISM = 10;

	/** Number of elements (each its own key) pushed through the operator in each phase. */
	private static final int NUM_ELEMENTS = 10;

	/** Number of partitions the operator writes to its raw operator state stream. */
	private static final int NUM_RAW_OPERATOR_PARTITIONS = 13;

	/** Offset added to each partition index when writing raw operator state. */
	private static final int RAW_OPERATOR_STATE_OFFSET = 42;

	/** Offset added to each key-group id when writing raw keyed state. */
	private static final int RAW_KEYED_STATE_OFFSET = 2;

	@Test
	public void testOperatorStatesSnapshotRestore() throws Exception {

		//-------------------------------------------------------------------------- snapshot

		KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
			createTestHarness(new TestOneInputStreamOperator(false));

		testHarness.open();

		for (int i = 0; i < NUM_ELEMENTS; ++i) {
			testHarness.processElement(new StreamRecord<>(i));
		}

		OperatorStateHandles handles = testHarness.snapshot(1L, 1L);

		testHarness.close();

		//-------------------------------------------------------------------------- restore

		testHarness = createTestHarness(new TestOneInputStreamOperator(true));

		testHarness.initializeState(handles);
		testHarness.open();

		// in restore mode the operator asserts the restored managed keyed state for each element
		for (int i = 0; i < NUM_ELEMENTS; ++i) {
			testHarness.processElement(new StreamRecord<>(i));
		}

		testHarness.close();
	}

	/**
	 * Creates the keyed harness used by both phases of the test.
	 *
	 * <p>The harness runs a single subtask (index 0) with {@link #MAX_PARALLELISM} key groups,
	 * keyed by the element value itself. Using one helper guarantees that the snapshot and the
	 * restore phase are configured identically.
	 *
	 * @param operator the operator under test
	 * @return a new, not yet opened harness wrapping {@code operator}
	 */
	private static KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> createTestHarness(
			TestOneInputStreamOperator operator) throws Exception {

		return new KeyedOneInputStreamOperatorTestHarness<>(
			operator,
			new IdentityKeySelector(),
			TypeInformation.of(Integer.class),
			MAX_PARALLELISM,
			1 /* num subtasks */,
			0 /* subtask index */);
	}

	/**
	 * Keys every element by its own value.
	 *
	 * <p>This is a static nested class rather than an anonymous class. An anonymous class created in
	 * an instance method would implicitly capture the (non-serializable) enclosing test instance.
	 */
	private static final class IdentityKeySelector implements KeySelector<Integer, Integer> {

		private static final long serialVersionUID = 1L;

		@Override
		public Integer getKey(Integer value) throws Exception {
			return value;
		}
	}

	/**
	 * Operator that runs in one of two modes.
	 *
	 * <p>In snapshot mode ({@code verifyRestore == false}) it writes managed keyed, managed operator,
	 * raw keyed and raw operator state. In restore mode ({@code verifyRestore == true}) it asserts
	 * that all of that state was restored.
	 */
	static class TestOneInputStreamOperator
			extends AbstractStreamOperator<Integer>
			implements OneInputStreamOperator<Integer, Integer> {

		private static final long serialVersionUID = -8942866418598856475L;

		/** {@code false} while producing the snapshot, {@code true} while verifying the restore. */
		private final boolean verifyRestore;

		private ValueState<Integer> keyedState;
		private ListState<Integer> opState;

		public TestOneInputStreamOperator(boolean verifyRestore) {
			this.verifyRestore = verifyRestore;
		}

		@Override
		public void processElement(StreamRecord<Integer> element) throws Exception {
			if (verifyRestore) {
				// check restored managed keyed state: the snapshot phase stored value + 1 per key
				long exp = element.getValue() + 1;
				long act = keyedState.value();
				Assert.assertEquals(exp, act);
			} else {
				// write managed keyed state that goes into snapshot
				keyedState.update(element.getValue() + 1);
				// write managed operator state that goes into snapshot
				opState.add(element.getValue());
			}
		}

		@Override
		public void processWatermark(Watermark mark) throws Exception {
			// watermarks play no role in this test
		}

		@Override
		public void snapshotState(StateSnapshotContext context) throws Exception {
			// write raw keyed state that goes into snapshot: one int per key group
			KeyedStateCheckpointOutputStream out = context.getRawKeyedOperatorStateOutput();
			DataOutputView dov = new DataOutputViewStreamWrapper(out);

			int count = 0;
			for (int kg : out.getKeyGroupList()) {
				out.startNewKeyGroup(kg);
				dov.writeInt(kg + RAW_KEYED_STATE_OFFSET);
				++count;
			}

			// with a single subtask, this operator is expected to own every key group
			Assert.assertEquals(MAX_PARALLELISM, count);

			// write raw operator state that goes into snapshot: one int per partition
			OperatorStateCheckpointOutputStream outOp = context.getRawOperatorStateOutput();
			dov = new DataOutputViewStreamWrapper(outOp);

			for (int i = 0; i < NUM_RAW_OPERATOR_PARTITIONS; ++i) {
				outOp.startNewPartition();
				dov.writeInt(RAW_OPERATOR_STATE_OFFSET + i);
			}
		}

		@Override
		public void initializeState(StateInitializationContext context) throws Exception {
			Assert.assertEquals(verifyRestore, context.isRestored());

			keyedState = context
				.getKeyedStateStore()
				.getState(new ValueStateDescriptor<>("managed-keyed", Integer.class, 0));

			opState = context
				.getOperatorStateStore()
				.getListState(new ListStateDescriptor<>("managed-op-state", IntSerializer.INSTANCE));

			if (context.isRestored()) {
				// check restored raw keyed state: one stream per key group, holding kg + offset
				int count = 0;
				for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
					try (InputStream in = streamProvider.getStream()) {
						DataInputView div = new DataInputViewStreamWrapper(in);
						Assert.assertEquals(
							streamProvider.getKeyGroupId() + RAW_KEYED_STATE_OFFSET, div.readInt());
						++count;
					}
				}

				Assert.assertEquals(MAX_PARALLELISM, count);

				// check restored managed operator state: every element value must be present
				BitSet check = new BitSet(NUM_ELEMENTS);
				for (int v : opState.get()) {
					check.set(v);
				}

				Assert.assertEquals(NUM_ELEMENTS, check.cardinality());

				// check restored raw operator state: every written partition must be present
				check = new BitSet(NUM_RAW_OPERATOR_PARTITIONS);
				for (StatePartitionStreamProvider streamProvider : context.getRawOperatorStateInputs()) {
					try (InputStream in = streamProvider.getStream()) {
						DataInputView div = new DataInputViewStreamWrapper(in);
						check.set(div.readInt() - RAW_OPERATOR_STATE_OFFSET);
					}
				}

				Assert.assertEquals(NUM_RAW_OPERATOR_PARTITIONS, check.cardinality());
			}
		}
	}
}