/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.runtime.operators;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.streaming.api.collector.selector.OutputSelector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SplitStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OperatorChain;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for stream operator chaining behaviour.
*/
public class StreamOperatorChainingTest {

	// We have to use static fields because the sink functions will go through serialization
	private static List<String> sink1Results;
	private static List<String> sink2Results;
	private static List<String> sink3Results;

	@Test
	public void testMultiChainingWithObjectReuse() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.getConfig().enableObjectReuse();

		testMultiChaining(env);
	}

	@Test
	public void testMultiChainingWithoutObjectReuse() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.getConfig().disableObjectReuse();

		testMultiChaining(env);
	}

	/**
	 * Verify that multi-chaining works: one map operator whose output fans out to two
	 * map -> sink branches, all chained into a single vertex.
	 *
	 * <p>The object-reuse setting of {@code env} is propagated to the operator chain, so this
	 * covers whichever mode the calling test configured.
	 */
	private void testMultiChaining(StreamExecutionEnvironment env) throws Exception {

		// the actual elements will not be used
		DataStream<Integer> input = env.fromElements(1, 2, 3);

		sink1Results = new ArrayList<>();
		sink2Results = new ArrayList<>();

		input = input
				.map(new MapFunction<Integer, Integer>() {
					private static final long serialVersionUID = 1L;

					@Override
					public Integer map(Integer value) throws Exception {
						return value;
					}
				});

		input
				.map(new MapFunction<Integer, String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public String map(Integer value) throws Exception {
						return "First: " + value;
					}
				})
				.addSink(new SinkFunction<String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public void invoke(String value) throws Exception {
						sink1Results.add(value);
					}
				});

		input
				.map(new MapFunction<Integer, String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public String map(Integer value) throws Exception {
						return "Second: " + value;
					}
				})
				.addSink(new SinkFunction<String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public void invoke(String value) throws Exception {
						sink2Results.add(value);
					}
				});

		runChainedVertex(env, 1, 2, 3);

		assertThat(sink1Results, contains("First: 1", "First: 2", "First: 3"));
		assertThat(sink2Results, contains("Second: 1", "Second: 2", "Second: 3"));
	}

	@Test
	public void testMultiChainingWithSplitWithObjectReuse() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.getConfig().enableObjectReuse();

		testMultiChainingWithSplit(env);
	}

	@Test
	public void testMultiChainingWithSplitWithoutObjectReuse() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.getConfig().disableObjectReuse();

		testMultiChainingWithSplit(env);
	}

	/**
	 * Verify that multi-chaining works when the chained head output is routed through a
	 * split/select: two branches select "one" and one branch selects "other".
	 *
	 * <p>The object-reuse setting of {@code env} is propagated to the operator chain, so this
	 * covers whichever mode the calling test configured.
	 */
	private void testMultiChainingWithSplit(StreamExecutionEnvironment env) throws Exception {

		// the actual elements will not be used
		DataStream<Integer> input = env.fromElements(1, 2, 3);

		sink1Results = new ArrayList<>();
		sink2Results = new ArrayList<>();
		sink3Results = new ArrayList<>();

		input = input
				.map(new MapFunction<Integer, Integer>() {
					private static final long serialVersionUID = 1L;

					@Override
					public Integer map(Integer value) throws Exception {
						return value;
					}
				});

		SplitStream<Integer> split = input.split(new OutputSelector<Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Iterable<String> select(Integer value) {
				if (value.equals(1)) {
					return Collections.singletonList("one");
				} else {
					return Collections.singletonList("other");
				}
			}
		});

		split.select("one")
				.map(new MapFunction<Integer, String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public String map(Integer value) throws Exception {
						return "First 1: " + value;
					}
				})
				.addSink(new SinkFunction<String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public void invoke(String value) throws Exception {
						sink1Results.add(value);
					}
				});

		split.select("one")
				.map(new MapFunction<Integer, String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public String map(Integer value) throws Exception {
						return "First 2: " + value;
					}
				})
				.addSink(new SinkFunction<String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public void invoke(String value) throws Exception {
						sink2Results.add(value);
					}
				});

		split.select("other")
				.map(new MapFunction<Integer, String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public String map(Integer value) throws Exception {
						return "Second: " + value;
					}
				})
				.addSink(new SinkFunction<String>() {
					private static final long serialVersionUID = 1L;

					@Override
					public void invoke(String value) throws Exception {
						sink3Results.add(value);
					}
				});

		runChainedVertex(env, 1, 2, 3);

		assertThat(sink1Results, contains("First 1: 1"));
		assertThat(sink2Results, contains("First 2: 1"));
		assertThat(sink3Results, contains("Second: 2", "Second: 3"));
	}

	/**
	 * Builds the job graph of {@code env}, instantiates its second (chained) vertex as an
	 * {@link OperatorChain} on a mocked {@link StreamTask}, opens all operators of that chain and
	 * pushes {@code elements} through the head operator.
	 *
	 * <p>We build our own StreamTask and OperatorChain instead of executing the job. The mocked
	 * task reports {@code env}'s {@link ExecutionConfig}, so the chain is assembled with the same
	 * object-reuse setting the calling test configured.
	 *
	 * @param env the environment whose topology has already been defined
	 * @param elements the values to feed into the head operator, in order
	 */
	private void runChainedVertex(StreamExecutionEnvironment env, Integer... elements) throws Exception {
		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
		List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();

		// expect the source vertex plus exactly one vertex holding the whole chained pipeline
		Assert.assertEquals(2, vertices.size());

		JobVertex chainedVertex = vertices.get(1);

		Configuration configuration = chainedVertex.getConfiguration();
		StreamConfig streamConfig = new StreamConfig(configuration);

		StreamMap<Integer, Integer> headOperator =
				streamConfig.getStreamOperator(Thread.currentThread().getContextClassLoader());

		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
				createMockTask(streamConfig, chainedVertex.getName(), env.getConfig());

		OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>(mockTask);

		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());

		for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
			if (operator != null) {
				operator.open();
			}
		}

		for (Integer element : elements) {
			headOperator.processElement(new StreamRecord<>(element));
		}
	}

	/**
	 * Creates a Mockito mock of {@link StreamTask} for driving an {@link OperatorChain} by hand.
	 *
	 * <p>The mock returns a fixed name, a private checkpoint lock, the given {@code streamConfig},
	 * a {@link MockEnvironment} named {@code taskName}, and the given {@code executionConfig}.
	 * The execution config must come from the test's environment; otherwise the chain would
	 * ignore the object-reuse mode under test.
	 *
	 * @param streamConfig the config of the chained vertex
	 * @param taskName the name used for the mock environment
	 * @param executionConfig the execution config (including object-reuse mode) to expose
	 * @return the configured mock task
	 */
	private <IN, OT extends StreamOperator<IN>> StreamTask<IN, OT> createMockTask(
			StreamConfig streamConfig,
			String taskName,
			ExecutionConfig executionConfig) {
		final Object checkpointLock = new Object();
		final Environment env = new MockEnvironment(taskName, 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);

		@SuppressWarnings("unchecked")
		StreamTask<IN, OT> mockTask = mock(StreamTask.class);
		when(mockTask.getName()).thenReturn("Mock Task");
		when(mockTask.getCheckpointLock()).thenReturn(checkpointLock);
		when(mockTask.getConfiguration()).thenReturn(streamConfig);
		when(mockTask.getEnvironment()).thenReturn(env);
		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);

		return mockTask;
	}
}