/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.util;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.util.OutputTag;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.fs.FSDataOutputStream;
import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0Serializer;
import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState;
import org.apache.flink.migration.util.MigrationInstantiationUtil;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.AbstractStreamOperatorTest;
import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.streamstatus.StreamStatus;
import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
import org.apache.flink.util.FutureUtil;
import org.apache.flink.util.Preconditions;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Base class for {@code AbstractStreamOperator} test harnesses.
*/
public class AbstractStreamOperatorTestHarness<OUT> {
/** The operator driven by this harness. */
protected final StreamOperator<OUT> operator;

/** Everything emitted on the main output: stream records, watermarks and latency markers. */
protected final ConcurrentLinkedQueue<Object> outputList;

/** Records emitted to side outputs, one queue per {@link OutputTag}; a queue is created on first emission. */
protected final Map<OutputTag<?>, ConcurrentLinkedQueue<Object>> sideOutputLists;

/** Stream configuration backed by the environment's task configuration. */
protected final StreamConfig config;

/** Execution configuration taken from the environment. */
protected final ExecutionConfig executionConfig;

/** Manually advanced processing time service handed to the operator through the mock task. */
protected final TestProcessingTimeService processingTimeService;

/** Mockito mock of the containing task that the operator is set up with. */
protected final StreamTask<?, ?> mockTask;

/** Environment reported by the mock task. */
final Environment environment;

/** Registry returned by the mock task's {@code getCancelables()}. */
CloseableRegistry closableRegistry;

// use this as default for tests
protected AbstractStateBackend stateBackend = new MemoryStateBackend();

/** Lock object returned by the mock task's {@code getCheckpointLock()}. */
private final Object checkpointLock;

/** Used to pick the operator state handles that belong to this subtask when state is initialized. */
private final OperatorStateRepartitioner operatorStateRepartitioner =
    RoundRobinOperatorStateRepartitioner.INSTANCE;

/**
 * Whether setup() was called on the operator. This is reset when calling close().
 */
private boolean setupCalled = false;

/** Whether {@code initializeState} has completed on this harness. */
private boolean initializeCalled = false;

/** Becomes {@code true} once {@code handleAsyncException} is invoked on the mock task. */
private volatile boolean wasFailedExternally = false;
/**
 * Creates a harness that runs the operator against a freshly created {@link MockEnvironment}
 * configured with the given parallelism settings.
 *
 * @param operator the operator under test
 * @param maxParallelism maximum parallelism passed to the mock environment
 * @param numSubtasks number of parallel subtasks passed to the mock environment
 * @param subtaskIndex index of the subtask this harness represents
 */
public AbstractStreamOperatorTestHarness(
        StreamOperator<OUT> operator,
        int maxParallelism,
        int numSubtasks,
        int subtaskIndex) throws Exception {
    this(
        operator,
        maxParallelism,
        numSubtasks,
        subtaskIndex,
        newDefaultMockEnvironment(maxParallelism, numSubtasks, subtaskIndex));
}

/** Builds the {@link MockEnvironment} used when the caller does not supply an environment. */
private static MockEnvironment newDefaultMockEnvironment(
        int maxParallelism,
        int numSubtasks,
        int subtaskIndex) throws Exception {
    return new MockEnvironment(
        "MockTask",
        3 * 1024 * 1024,
        new MockInputSplitProvider(),
        1024,
        new Configuration(),
        new ExecutionConfig(),
        maxParallelism,
        numSubtasks,
        subtaskIndex);
}
/**
 * Creates a harness that runs the operator against the given {@link Environment}.
 *
 * <p>The operator is wired to a Mockito mock of {@link StreamTask}. The mock hands out this
 * harness' configuration, checkpoint lock, closeable registry, processing time service and
 * stream status maintainer, and creates checkpoint stream factories and operator state
 * backends from the currently configured state backend.
 *
 * @param operator the operator under test
 * @param maxParallelism not read by this constructor; parallelism information is taken from
 *                       the environment's task info when state is initialized
 * @param numSubtasks not read by this constructor (see {@code maxParallelism})
 * @param subtaskIndex not read by this constructor (see {@code maxParallelism})
 * @param environment the environment reported by the mock task, must not be {@code null}
 * @throws NullPointerException if {@code environment} is {@code null}
 */
public AbstractStreamOperatorTestHarness(
        StreamOperator<OUT> operator,
        int maxParallelism,
        int numSubtasks,
        int subtaskIndex,
        final Environment environment) throws Exception {
    // Validate before anything else: the statements below dereference the environment,
    // so a later check could never report the problem.
    this.environment = Preconditions.checkNotNull(environment);
    this.operator = operator;
    this.outputList = new ConcurrentLinkedQueue<>();
    this.sideOutputLists = new HashMap<>();
    Configuration underlyingConfig = environment.getTaskConfiguration();
    this.config = new StreamConfig(underlyingConfig);
    this.config.setCheckpointingEnabled(true);
    this.executionConfig = environment.getExecutionConfig();
    this.closableRegistry = new CloseableRegistry();
    this.checkpointLock = new Object();
    mockTask = mock(StreamTask.class);
    processingTimeService = new TestProcessingTimeService();
    processingTimeService.setCurrentTime(0);

    // Minimal maintainer that just remembers the last status it was toggled to.
    StreamStatusMaintainer mockStreamStatusMaintainer = new StreamStatusMaintainer() {
        StreamStatus currentStreamStatus = StreamStatus.ACTIVE;

        @Override
        public void toggleStreamStatus(StreamStatus streamStatus) {
            if (!currentStreamStatus.equals(streamStatus)) {
                currentStreamStatus = streamStatus;
            }
        }

        @Override
        public StreamStatus getStreamStatus() {
            return currentStreamStatus;
        }
    };

    when(mockTask.getName()).thenReturn("Mock Task");
    when(mockTask.getCheckpointLock()).thenReturn(checkpointLock);
    when(mockTask.getConfiguration()).thenReturn(config);
    when(mockTask.getTaskConfiguration()).thenReturn(underlyingConfig);
    when(mockTask.getEnvironment()).thenReturn(environment);
    when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
    when(mockTask.getUserCodeClassLoader()).thenReturn(this.getClass().getClassLoader());
    when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
    when(mockTask.getStreamStatusMaintainer()).thenReturn(mockStreamStatusMaintainer);

    // Record asynchronous failures instead of failing a real task.
    doAnswer(new Answer<Void>() {
        @Override
        public Void answer(InvocationOnMock invocation) throws Throwable {
            wasFailedExternally = true;
            return null;
        }
    }).when(mockTask).handleAsyncException(any(String.class), any(Throwable.class));

    try {
        // The state backend field is read at call time, so setStateBackend() takes effect
        // for factories created after it was called.
        doAnswer(new Answer<CheckpointStreamFactory>() {
            @Override
            public CheckpointStreamFactory answer(InvocationOnMock invocationOnMock) throws Throwable {
                final StreamOperator<?> operator = (StreamOperator<?>) invocationOnMock.getArguments()[0];
                return stateBackend.createStreamFactory(new JobID(), operator.getClass().getSimpleName());
            }
        }).when(mockTask).createCheckpointStreamFactory(any(StreamOperator.class));
    } catch (Exception e) {
        throw new RuntimeException(e.getMessage(), e);
    }

    try {
        doAnswer(new Answer<OperatorStateBackend>() {
            @Override
            public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
                final StreamOperator<?> operator = (StreamOperator<?>) invocationOnMock.getArguments()[0];
                @SuppressWarnings("unchecked")
                final Collection<OperatorStateHandle> stateHandles =
                    (Collection<OperatorStateHandle>) invocationOnMock.getArguments()[1];
                OperatorStateBackend osb;
                osb = stateBackend.createOperatorStateBackend(
                    environment,
                    operator.getClass().getSimpleName());
                // Register the backend with the task's closeable registry.
                mockTask.getCancelables().registerClosable(osb);
                if (null != stateHandles) {
                    osb.restore(stateHandles);
                }
                return osb;
            }
        }).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class));
    } catch (Exception e) {
        throw new RuntimeException(e.getMessage(), e);
    }

    doAnswer(new Answer<ProcessingTimeService>() {
        @Override
        public ProcessingTimeService answer(InvocationOnMock invocation) throws Throwable {
            return processingTimeService;
        }
    }).when(mockTask).getProcessingTimeService();
}
/**
 * Replaces the state backend used for stream factories and operator state backends that are
 * created after this call.
 *
 * @param stateBackend the backend to use from now on
 */
public void setStateBackend(final AbstractStateBackend stateBackend) {
    this.stateBackend = stateBackend;
}
/**
 * Returns the checkpoint lock reported by the mock task.
 */
public Object getCheckpointLock() {
    final Object lock = this.mockTask.getCheckpointLock();
    return lock;
}
/**
 * Returns the environment reported by the mock task.
 */
public Environment getEnvironment() {
    final Environment taskEnvironment = mockTask.getEnvironment();
    return taskEnvironment;
}
/**
 * Returns the execution configuration that was read from the environment at construction time.
 */
public ExecutionConfig getExecutionConfig() {
    return this.executionConfig;
}
/**
 * Returns the live queue of everything emitted on the main output. Stream records,
 * watermarks and latency markers appear interleaved in emission order.
 */
public ConcurrentLinkedQueue<Object> getOutput() {
    return this.outputList;
}
/**
 * Returns the queue of records emitted to the side output identified by {@code tag}.
 *
 * @param tag the side output to look up
 * @return the queue for that tag, or {@code null} if nothing was emitted to it yet
 */
@SuppressWarnings({"unchecked", "rawtypes"})
public <X> ConcurrentLinkedQueue<StreamRecord<X>> getSideOutput(OutputTag<X> tag) {
    final ConcurrentLinkedQueue untypedQueue = sideOutputLists.get(tag);
    return untypedQueue;
}
/**
 * Collects the {@link StreamRecord StreamRecords} from the main output, skipping every other
 * kind of emitted element.
 *
 * @return a new list holding the emitted records in emission order
 */
@SuppressWarnings("unchecked")
public List<StreamRecord<? extends OUT>> extractOutputStreamRecords() {
    final List<StreamRecord<? extends OUT>> records = new LinkedList<>();
    for (final Object emitted : getOutput()) {
        if (!(emitted instanceof StreamRecord)) {
            continue;
        }
        final StreamRecord<OUT> record = (StreamRecord<OUT>) emitted;
        records.add(record);
    }
    return records;
}
/**
 * Calls {@link StreamOperator#setup(StreamTask, StreamConfig, Output)} without an explicit
 * output serializer; the serializer is then derived from the first emitted element.
 */
public void setup() {
    this.setup(null);
}
/**
 * Calls {@link StreamOperator#setup(StreamTask, StreamConfig, Output)} with the mock task,
 * the harness configuration and an output that records emitted elements.
 *
 * @param outputSerializer serializer used to copy emitted values; if {@code null}, one is
 *                         derived from the first emitted element
 */
public void setup(TypeSerializer<OUT> outputSerializer) {
    final Output<StreamRecord<OUT>> recordingOutput = new MockOutput(outputSerializer);
    operator.setup(mockTask, config, recordingOutput);
    setupCalled = true;
}
/**
 * Reads a serialized legacy {@link StreamTaskState} from the given file, converts it into the
 * current state handle types and initializes the operator with it. Calls
 * {@link #setup()} first if it was not called before.
 *
 * @param checkpointFilename path of the file holding the serialized legacy task state
 * @throws Exception if the file cannot be read or deserialized, or if state initialization fails
 */
public void initializeStateFromLegacyCheckpoint(String checkpointFilename) throws Exception {
    final StreamTaskState state;
    // try-with-resources so the file handle is released even if deserialization fails
    try (FileInputStream fin = new FileInputStream(checkpointFilename)) {
        state = MigrationInstantiationUtil.deserializeObject(fin, ClassLoader.getSystemClassLoader());
    }

    if (!setupCalled) {
        setup();
    }

    StreamStateHandle stateHandle = SavepointV0Serializer.convertOperatorAndFunctionState(state);

    List<KeyedStateHandle> keyGroupStatesList = new ArrayList<>();
    if (state.getKvStates() != null) {
        KeyGroupsStateHandle keyedStateHandle = SavepointV0Serializer.convertKeyedBackendState(
            state.getKvStates(),
            environment.getTaskInfo().getIndexOfThisSubtask(),
            0);
        keyGroupStatesList.add(keyedStateHandle);
    }

    // finally calling the initializeState() with the legacy operatorStateHandles
    initializeState(new OperatorStateHandles(0,
        stateHandle,
        keyGroupStatesList,
        Collections.<KeyedStateHandle>emptyList(),
        Collections.<OperatorStateHandle>emptyList(),
        Collections.<OperatorStateHandle>emptyList()));
}
/**
 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 * if it was not called before.
 *
 * <p>The given handles are narrowed down before they reach the operator: keyed state is
 * restricted to the key-group range of this subtask, and operator state is repartitioned so
 * that only the share assigned to this subtask is passed on. The legacy operator state is
 * handed through unchanged.
 *
 * @param operatorStateHandles the state to restore, or {@code null} to initialize without state
 */
public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
    if (!setupCalled) {
        setup();
    }

    // No state to restore: initialize the operator empty and finish.
    if (operatorStateHandles == null) {
        operator.initializeState(null);
        initializeCalled = true;
        return;
    }

    final int maxParallelism = getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks();
    final int parallelism = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
    final int ownIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();

    // Key-group range owned by this subtask.
    final List<KeyGroupRange> allRanges =
        StateAssignmentOperation.createKeyGroupPartitions(maxParallelism, parallelism);
    final KeyGroupRange ownRange = allRanges.get(ownIndex);

    // Keyed state: keep only the handles intersecting our key-group range.
    final Collection<KeyedStateHandle> allManagedKeyed = operatorStateHandles.getManagedKeyedState();
    List<KeyedStateHandle> ownManagedKeyed = null;
    if (allManagedKeyed != null) {
        ownManagedKeyed = StateAssignmentOperation.getKeyedStateHandles(allManagedKeyed, ownRange);
    }

    final Collection<KeyedStateHandle> allRawKeyed = operatorStateHandles.getRawKeyedState();
    List<KeyedStateHandle> ownRawKeyed = null;
    if (allRawKeyed != null) {
        ownRawKeyed = StateAssignmentOperation.getKeyedStateHandles(allRawKeyed, ownRange);
    }

    // Operator state: repartition across all subtasks and take our share.
    final List<OperatorStateHandle> allManagedOperator = new ArrayList<>();
    final Collection<OperatorStateHandle> givenManagedOperator = operatorStateHandles.getManagedOperatorState();
    if (givenManagedOperator != null) {
        allManagedOperator.addAll(givenManagedOperator);
    }
    final Collection<OperatorStateHandle> ownManagedOperator =
        operatorStateRepartitioner.repartitionState(allManagedOperator, parallelism).get(ownIndex);

    final List<OperatorStateHandle> allRawOperator = new ArrayList<>();
    final Collection<OperatorStateHandle> givenRawOperator = operatorStateHandles.getRawOperatorState();
    if (givenRawOperator != null) {
        allRawOperator.addAll(givenRawOperator);
    }
    final Collection<OperatorStateHandle> ownRawOperator =
        operatorStateRepartitioner.repartitionState(allRawOperator, parallelism).get(ownIndex);

    final OperatorStateHandles localHandles = new OperatorStateHandles(
        0,
        operatorStateHandles.getLegacyOperatorState(),
        ownManagedKeyed,
        ownRawKeyed,
        ownManagedOperator,
        ownRawOperator);

    operator.initializeState(localHandles);
    initializeCalled = true;
}
/**
 * Takes the different {@link OperatorStateHandles} created by calling {@link #snapshot(long, long)}
 * on different instances of {@link AbstractStreamOperatorTestHarness} (each one representing one subtask)
 * and repacks them into a single {@link OperatorStateHandles} so that the parallelism of the test
 * can change arbitrarily (i.e. be able to scale both up and down).
 *
 * <p>After repacking the partial states, use {@link #initializeState(OperatorStateHandles)} to initialize
 * a new instance with the resulting state. Bear in mind that for parallelism greater than one, you
 * have to use the constructor {@link #AbstractStreamOperatorTestHarness(StreamOperator, int, int, int)}.
 *
 * <p><b>NOTE: </b> each of the {@code handles} in the argument list is assumed to be from a single task of a single
 * operator (i.e. chain length of one).
 *
 * <p>A single handle is returned as-is. When several handles are merged, the legacy operator
 * state is not carried over into the result.
 *
 * <p>For an example of how to use it, have a look at
 * {@link AbstractStreamOperatorTest#testStateAndTimerStateShufflingScalingDown()}.
 *
 * @param handles the different states to be merged.
 * @return the resulting state, or {@code null} if no partial states are specified.
 */
public static OperatorStateHandles repackageState(OperatorStateHandles... handles) throws Exception {
    final int count = handles.length;
    if (count == 0) {
        return null;
    }
    if (count == 1) {
        return handles[0];
    }

    final List<OperatorStateHandle> allManagedOperator = new ArrayList<>(count);
    final List<OperatorStateHandle> allRawOperator = new ArrayList<>(count);
    final List<KeyedStateHandle> allManagedKeyed = new ArrayList<>(count);
    final List<KeyedStateHandle> allRawKeyed = new ArrayList<>(count);

    for (int i = 0; i < count; i++) {
        final OperatorStateHandles partial = handles[i];
        appendIfPresent(allManagedOperator, partial.getManagedOperatorState());
        appendIfPresent(allRawOperator, partial.getRawOperatorState());
        appendIfPresent(allManagedKeyed, partial.getManagedKeyedState());
        appendIfPresent(allRawKeyed, partial.getRawKeyedState());
    }

    return new OperatorStateHandles(
        0,
        null,
        allManagedKeyed,
        allRawKeyed,
        allManagedOperator,
        allRawOperator);
}

/** Appends all elements of {@code source} to {@code target}, treating a {@code null} source as empty. */
private static <T> void appendIfPresent(List<T> target, Collection<T> source) {
    if (source != null) {
        target.addAll(source);
    }
}
/**
 * Calls {@link StreamOperator#open()}. If the state was not initialized before, this first
 * calls {@link #initializeState(OperatorStateHandles)} with {@code null}, which in turn calls
 * {@link StreamOperator#setup(StreamTask, StreamConfig, Output)} if that was not done yet.
 */
public void open() throws Exception {
    final boolean needsInitialization = !initializeCalled;
    if (needsInitialization) {
        initializeState(null);
    }
    operator.open();
}
/**
 * Calls {@link StreamOperator#snapshotState(long, long, CheckpointOptions)} with full-checkpoint
 * options, completes the resulting futures and bundles their results.
 *
 * <p>If the operator is a {@link StreamCheckpointedOperator}, its legacy state is written as
 * well and included in the result.
 *
 * @param checkpointId id passed to the operator's snapshot methods
 * @param timestamp timestamp passed to the operator's snapshot methods
 * @return the handles of the snapshot; each state collection is {@code null} when the
 *         corresponding future produced no handle
 */
public OperatorStateHandles snapshot(long checkpointId, long timestamp) throws Exception {
    final CheckpointStreamFactory legacyStreamFactory =
        stateBackend.createStreamFactory(new JobID(), "test_op");

    final OperatorSnapshotResult snapshotResult = operator.snapshotState(
        checkpointId,
        timestamp,
        CheckpointOptions.forFullCheckpoint());

    final KeyedStateHandle managedKeyed =
        FutureUtil.runIfNotDoneAndGet(snapshotResult.getKeyedStateManagedFuture());
    final KeyedStateHandle rawKeyed =
        FutureUtil.runIfNotDoneAndGet(snapshotResult.getKeyedStateRawFuture());
    final OperatorStateHandle managedOperator =
        FutureUtil.runIfNotDoneAndGet(snapshotResult.getOperatorStateManagedFuture());
    final OperatorStateHandle rawOperator =
        FutureUtil.runIfNotDoneAndGet(snapshotResult.getOperatorStateRawFuture());

    // also snapshot legacy state, if any
    StreamStateHandle legacyState = null;
    if (operator instanceof StreamCheckpointedOperator) {
        final StreamCheckpointedOperator legacyOperator = (StreamCheckpointedOperator) operator;
        final CheckpointStreamFactory.CheckpointStateOutputStream legacyStream =
            legacyStreamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
        legacyOperator.snapshotState(legacyStream, checkpointId, timestamp);
        legacyState = legacyStream.closeAndGetHandle();
    }

    return new OperatorStateHandles(
        0,
        legacyState,
        singletonListOrNull(managedKeyed),
        singletonListOrNull(rawKeyed),
        singletonListOrNull(managedOperator),
        singletonListOrNull(rawOperator));
}

/** Wraps a non-null element into a singleton list and maps {@code null} to {@code null}. */
private static <T> List<T> singletonListOrNull(T element) {
    if (element == null) {
        return null;
    }
    return Collections.singletonList(element);
}
/**
 * Calls {@link StreamCheckpointedOperator#snapshotState(FSDataOutputStream, long, long)} if
 * the operator implements this interface.
 *
 * <p>The checkpoint output stream is only created once the operator is known to support
 * legacy snapshots, and it is closed again if writing the snapshot fails, so no stream is
 * left open on either error path.
 *
 * @param checkpointId id passed to the operator's snapshot method
 * @param timestamp timestamp passed to the operator's snapshot method
 * @return the handle to the written legacy state
 * @throws RuntimeException if the operator is not a {@link StreamCheckpointedOperator}
 */
@Deprecated
public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
    // Check first so that no stream is opened for an operator that cannot use it.
    if (!(operator instanceof StreamCheckpointedOperator)) {
        throw new RuntimeException("Operator is not StreamCheckpointedOperator");
    }

    CheckpointStreamFactory.CheckpointStateOutputStream outStream = stateBackend.createStreamFactory(
        new JobID(),
        "test_op").createCheckpointStateOutputStream(checkpointId, timestamp);

    try {
        ((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp);
        return outStream.closeAndGetHandle();
    } catch (Exception e) {
        // Release the stream on failure; keep the original exception as the primary one.
        try {
            outStream.close();
        } catch (Exception closeFailure) {
            e.addSuppressed(closeFailure);
        }
        throw e;
    }
}
/**
 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#notifyOfCompletedCheckpoint(long)}.
 *
 * @param checkpointId id of the checkpoint reported as completed
 */
public void notifyOfCompletedCheckpoint(final long checkpointId) throws Exception {
    this.operator.notifyOfCompletedCheckpoint(checkpointId);
}
/**
 * Calls {@link StreamCheckpointedOperator#restoreState(FSDataInputStream)} if
 * the operator implements this interface. The input stream opened from the handle is closed
 * afterwards.
 *
 * @param snapshot handle of the legacy state to restore
 * @throws RuntimeException if the operator is not a {@link StreamCheckpointedOperator}
 */
@Deprecated
@SuppressWarnings("deprecation")
public void restore(StreamStateHandle snapshot) throws Exception {
    if (!(operator instanceof StreamCheckpointedOperator)) {
        throw new RuntimeException("Operator is not StreamCheckpointedOperator");
    }
    final StreamCheckpointedOperator legacyOperator = (StreamCheckpointedOperator) operator;
    try (FSDataInputStream stateStream = snapshot.openInputStream()) {
        legacyOperator.restoreState(stateStream);
    }
}
/**
 * Calls close and then dispose on the operator, shuts down the processing time service and
 * clears the setup flag, so that a later state initialization sets the operator up again.
 * The flag tracking state initialization is left untouched.
 */
public void close() throws Exception {
    // dispose() is only reached when close() returned normally
    this.operator.close();
    this.operator.dispose();

    if (null != this.processingTimeService) {
        this.processingTimeService.shutdownService();
    }

    this.setupCalled = false;
}
/**
 * Sets the current time of the test processing time service.
 *
 * @param time the new current processing time
 */
public void setProcessingTime(final long time) throws Exception {
    this.processingTimeService.setCurrentTime(time);
}
/**
 * Returns the current time reported by the test processing time service.
 */
public long getProcessingTime() {
    final long now = this.processingTimeService.getCurrentProcessingTime();
    return now;
}
/**
 * Stores the given time characteristic in the harness' stream configuration.
 *
 * @param timeCharacteristic the time characteristic to configure
 */
public void setTimeCharacteristic(final TimeCharacteristic timeCharacteristic) {
    config.setTimeCharacteristic(timeCharacteristic);
}
/**
 * Returns the time characteristic currently stored in the harness' stream configuration.
 */
public TimeCharacteristic getTimeCharacteristic() {
    final TimeCharacteristic configured = config.getTimeCharacteristic();
    return configured;
}
/**
 * Returns whether {@code handleAsyncException} was invoked on the mock task since this
 * harness was created.
 */
public boolean wasFailedExternally() {
    return this.wasFailedExternally;
}
/**
 * Returns the number of processing time timers reported by the operator.
 *
 * @throws UnsupportedOperationException if the operator is not an {@link AbstractStreamOperator}
 */
@VisibleForTesting
public int numProcessingTimeTimers() {
    if (!(operator instanceof AbstractStreamOperator)) {
        throw new UnsupportedOperationException();
    }
    final AbstractStreamOperator<?> abstractOperator = (AbstractStreamOperator<?>) operator;
    return abstractOperator.numProcessingTimeTimers();
}
/**
 * Returns the number of event time timers reported by the operator.
 *
 * @throws UnsupportedOperationException if the operator is not an {@link AbstractStreamOperator}
 */
@VisibleForTesting
public int numEventTimeTimers() {
    if (!(operator instanceof AbstractStreamOperator)) {
        throw new UnsupportedOperationException();
    }
    final AbstractStreamOperator<?> abstractOperator = (AbstractStreamOperator<?>) operator;
    return abstractOperator.numEventTimeTimers();
}
/**
 * Output handed to the operator under test. Everything emitted is appended to the harness'
 * output queues; record values are copied with a serializer before they are stored.
 */
private class MockOutput implements Output<StreamRecord<OUT>> {

    /** Serializer for main-output values; derived from the first emitted element if not given. */
    private TypeSerializer<OUT> outputSerializer;

    /** Serializer derived from the most recent side-output record. */
    private TypeSerializer sideOutputSerializer;

    MockOutput() {
        this(null);
    }

    MockOutput(TypeSerializer<OUT> outputSerializer) {
        this.outputSerializer = outputSerializer;
    }

    @Override
    public void emitWatermark(Watermark mark) {
        outputList.add(mark);
    }

    @Override
    public void emitLatencyMarker(LatencyMarker latencyMarker) {
        outputList.add(latencyMarker);
    }

    @Override
    public void collect(StreamRecord<OUT> element) {
        // Lazily derive the serializer from the first element when none was supplied.
        if (null == outputSerializer) {
            outputSerializer = TypeExtractor.getForObject(element.getValue()).createSerializer(executionConfig);
        }

        final StreamRecord<OUT> stored;
        if (element.hasTimestamp()) {
            final OUT valueCopy = outputSerializer.copy(element.getValue());
            stored = new StreamRecord<OUT>(valueCopy, element.getTimestamp());
        } else {
            final OUT valueCopy = outputSerializer.copy(element.getValue());
            stored = new StreamRecord<OUT>(valueCopy);
        }
        outputList.add(stored);
    }

    @Override
    public <X> void collect(OutputTag<X> outputTag, StreamRecord<X> record) {
        // The side-output serializer is re-derived from every record.
        sideOutputSerializer = TypeExtractor.getForObject(record.getValue()).createSerializer(executionConfig);

        // Create the queue for this tag on first use.
        ConcurrentLinkedQueue<Object> tagQueue = sideOutputLists.get(outputTag);
        if (null == tagQueue) {
            tagQueue = new ConcurrentLinkedQueue<>();
            sideOutputLists.put(outputTag, tagQueue);
        }

        final Object valueCopy = sideOutputSerializer.copy(record.getValue());
        if (record.hasTimestamp()) {
            tagQueue.add(new StreamRecord<Object>(valueCopy, record.getTimestamp()));
        } else {
            tagQueue.add(new StreamRecord<Object>(valueCopy));
        }
    }

    @Override
    public void close() {
        // ignore
    }
}
}