/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.operators;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StateInitializationContextImpl;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.util.LongArrayList;
import org.apache.flink.util.Preconditions;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.mockito.Mockito.mock;
/**
 * Tests for {@link StateInitializationContextImpl}.
 *
 * <p>Every test runs against a context that {@link #setUp()} builds from {@link #NUM_HANDLES}
 * raw keyed state handles and {@link #NUM_HANDLES} raw operator state handles, all backed by
 * in-memory byte arrays ({@link ByteStateHandleCloseChecking}).
 */
public class StateInitializationContextImplTest {

    /** Number of keyed state handles and of operator state handles created by {@link #setUp()}. */
    static final int NUM_HANDLES = 10;

    /**
     * The size of the handle with index {@code i} is {@code i % SIZE_CYCLE}, so the fixture
     * contains handles of several different sizes, including a size of zero.
     */
    private static final int SIZE_CYCLE = 4;

    private StateInitializationContextImpl initializationContext;
    private CloseableRegistry closableRegistry;

    /** Number of key groups written into the keyed state handles by {@link #setUp()}. */
    private int writtenKeyGroups;

    /** All values written into the operator state handles by {@link #setUp()}. */
    private Set<Integer> writtenOperatorStates;

    /**
     * Builds the keyed and operator state handles and the context under test.
     *
     * <p>Keyed state: each key group's partition holds a single int, the key group id itself.
     * Operator state: partition {@code s} of handle {@code i} holds the single int
     * {@code i * NUM_HANDLES + s}.
     */
    @Before
    public void setUp() throws Exception {
        this.writtenKeyGroups = 0;
        this.writtenOperatorStates = new HashSet<>();
        this.closableRegistry = new CloseableRegistry();

        OperatorStateStore stateStore = mock(OperatorStateStore.class);
        ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64);

        List<KeyedStateHandle> keyedStateHandles = new ArrayList<>(NUM_HANDLES);
        int prev = 0;
        for (int i = 0; i < NUM_HANDLES; ++i) {
            out.reset();
            int size = i % SIZE_CYCLE;
            int end = prev + size;
            DataOutputView dov = new DataOutputViewStreamWrapper(out);

            // The last handle deliberately gets an empty key group range.
            boolean isLastHandle = (i == NUM_HANDLES - 1);
            KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(
                isLastHandle ? KeyGroupRange.EMPTY_KEY_GROUP_RANGE : new KeyGroupRange(prev, end));
            prev = end + 1;

            for (int kg : offsets.getKeyGroupRange()) {
                // Record where this key group starts, then write its id as the payload.
                offsets.setKeyGroupOffset(kg, out.getPosition());
                dov.writeInt(kg);
                ++writtenKeyGroups;
            }

            KeyedStateHandle handle = new KeyGroupsStateHandle(
                offsets, new ByteStateHandleCloseChecking("kg-" + i, out.toByteArray()));
            keyedStateHandles.add(handle);
        }

        List<OperatorStateHandle> operatorStateHandles = new ArrayList<>(NUM_HANDLES);
        for (int i = 0; i < NUM_HANDLES; ++i) {
            int size = i % SIZE_CYCLE;
            out.reset();
            DataOutputView dov = new DataOutputViewStreamWrapper(out);
            LongArrayList offsets = new LongArrayList(size);

            for (int s = 0; s < size; ++s) {
                offsets.add(out.getPosition());
                int val = i * NUM_HANDLES + s;
                dov.writeInt(val);
                writtenOperatorStates.add(val);
            }

            Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
            offsetsMap.put(
                DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
                new OperatorStateHandle.StateMetaInfo(
                    offsets.toArray(), OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));

            OperatorStateHandle operatorStateHandle = new OperatorStateHandle(
                offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
            operatorStateHandles.add(operatorStateHandle);
        }

        this.initializationContext =
            new StateInitializationContextImpl(
                true,
                stateStore,
                mock(KeyedStateStore.class),
                keyedStateHandles,
                operatorStateHandles,
                closableRegistry);
    }

    /**
     * Verifies that the raw operator state streams are handed out in the order in which
     * {@link #setUp()} wrote them, and that none is missing or added.
     */
    @Test
    public void getOperatorStateStreams() throws Exception {
        // Rebuild the sequence of values exactly as setUp() wrote it: handle by handle,
        // partition by partition. Handles of size zero contribute nothing.
        List<Integer> expectedValues = new ArrayList<>();
        for (int handle = 0; handle < NUM_HANDLES; ++handle) {
            int size = handle % SIZE_CYCLE;
            for (int partition = 0; partition < size; ++partition) {
                expectedValues.add(handle * NUM_HANDLES + partition);
            }
        }

        List<Integer> readValues = new ArrayList<>(expectedValues.size());
        for (StatePartitionStreamProvider streamProvider
                : initializationContext.getRawOperatorStateInputs()) {
            Assert.assertNotNull(streamProvider);
            try (InputStream is = streamProvider.getStream()) {
                DataInputView div = new DataInputViewStreamWrapper(is);
                readValues.add(div.readInt());
            }
        }

        // Comparing whole lists also catches an empty or truncated iterable, which a
        // per-stream check alone would let pass.
        Assert.assertEquals(expectedValues, readValues);
    }

    /**
     * Verifies that every raw keyed state stream starts with the id of the key group it is
     * reported for, and that exactly as many streams are provided as key groups were written.
     */
    @Test
    public void getKeyedStateStreams() throws Exception {
        int readKeyGroupCount = 0;

        for (KeyGroupStatePartitionStreamProvider stateStreamProvider
                : initializationContext.getRawKeyedStateInputs()) {
            Assert.assertNotNull(stateStreamProvider);

            try (InputStream is = stateStreamProvider.getStream()) {
                DataInputView div = new DataInputViewStreamWrapper(is);
                int val = div.readInt();
                ++readKeyGroupCount;
                Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
            }
        }

        Assert.assertEquals(writtenKeyGroups, readKeyGroupCount);
    }

    /**
     * Verifies that the set of values read from the raw operator state streams equals the set
     * of values written, and that no value is delivered twice.
     */
    @Test
    public void getOperatorStateStore() throws Exception {
        Set<Integer> readStatesCount = new HashSet<>();

        for (StatePartitionStreamProvider statePartitionStreamProvider
                : initializationContext.getRawOperatorStateInputs()) {
            Assert.assertNotNull(statePartitionStreamProvider);

            try (InputStream is = statePartitionStreamProvider.getStream()) {
                DataInputView div = new DataInputViewStreamWrapper(is);
                // add() returns false for a duplicate, which would mean a partition was served twice.
                Assert.assertTrue(readStatesCount.add(div.readInt()));
            }
        }

        Assert.assertEquals(writtenOperatorStates, readStatesCount);
    }

    /**
     * Closes the context after {@code NUM_HANDLES / 2} key group streams have been read and
     * verifies that the close takes effect.
     *
     * <p>The test only succeeds if an {@link IOException} escapes the loop after the close.
     * A successful read after the close, or an iteration that runs to completion, fails it.
     */
    @Test
    public void close() throws Exception {
        int count = 0;
        int stopCount = NUM_HANDLES / 2;
        boolean isClosed = false;

        try {
            for (KeyGroupStatePartitionStreamProvider stateStreamProvider
                    : initializationContext.getRawKeyedStateInputs()) {
                Assert.assertNotNull(stateStreamProvider);

                if (count == stopCount) {
                    initializationContext.close();
                    isClosed = true;
                }

                try (InputStream is = stateStreamProvider.getStream()) {
                    DataInputView div = new DataInputViewStreamWrapper(is);
                    try {
                        int val = div.readInt();
                        Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
                        if (isClosed) {
                            // Reading succeeded although the context was already closed.
                            Assert.fail("Close was ignored: stream");
                        }
                        ++count;
                    } catch (IOException ioex) {
                        // A failing read is only acceptable once the context has been closed.
                        if (!isClosed) {
                            throw ioex;
                        }
                    }
                }
            }
            // The loop ran to the end without any IOException escaping it.
            Assert.fail("Close was ignored: registry");
        } catch (IOException iex) {
            Assert.assertTrue(isClosed);
            Assert.assertEquals(stopCount, count);
        }
    }

    /**
     * A {@link ByteStreamStateHandle} whose input streams throw an {@link IOException} from
     * {@code read()} once they have been closed, so tests can detect reads on closed streams.
     */
    static final class ByteStateHandleCloseChecking extends ByteStreamStateHandle {

        private static final long serialVersionUID = -6201941296931334140L;

        public ByteStateHandleCloseChecking(String handleName, byte[] data) {
            super(handleName, data);
        }

        /** Opens a fresh stream over the handle's bytes, positioned at index 0. */
        @Override
        public FSDataInputStream openInputStream() throws IOException {
            return new FSDataInputStream() {
                private int index = 0;
                private boolean closed = false;

                @Override
                public void seek(long desired) throws IOException {
                    // Every value in [0, Integer.MAX_VALUE] survives the narrowing cast below.
                    Preconditions.checkArgument(desired >= 0 && desired <= Integer.MAX_VALUE);
                    index = (int) desired;
                }

                @Override
                public long getPos() throws IOException {
                    return index;
                }

                @Override
                public int read() throws IOException {
                    if (closed) {
                        throw new IOException("Stream closed");
                    }
                    // Mask to return the byte as an unsigned value; -1 signals end of data.
                    return index < data.length ? data[index++] & 0xFF : -1;
                }

                @Override
                public void close() throws IOException {
                    super.close();
                    this.closed = true;
                }
            };
        }
    }
}