/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.streaming.runtime;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
import org.apache.flink.streaming.api.datastream.IterativeStream;
import org.apache.flink.streaming.api.datastream.IterativeStream.ConnectedIterativeStreams;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.datastream.SplitStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.CoMapFunction;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
import org.apache.flink.test.streaming.runtime.util.EvenOddOutputSelector;
import org.apache.flink.test.streaming.runtime.util.NoOpIntMap;
import org.apache.flink.test.streaming.runtime.util.ReceiveCheckNoOpSink;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MathUtils;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@SuppressWarnings({ "unchecked", "unused", "serial" })
public class IterateITCase extends StreamingMultipleProgramsTestBase {
private static final Logger LOG = LoggerFactory.getLogger(IterateITCase.class);
private static boolean[] iterated;
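/**
* Closing an iteration with a feedback stream whose parallelism differs from that of the
* iteration is expected to be rejected.
*/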
@Test(expected = UnsupportedOperationException.class)
public void testIncorrectParallelism() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Integer> source = env.fromElements(1, 10);
IterativeStream<Integer> iter1 = source.iterate();
SingleOutputStreamOperator<Integer> map1 = iter1.map(NoOpIntMap);
iter1.closeWith(map1).print();
}
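/**
* An iteration may be closed with more than one feedback stream; calling closeWith twice
* must not throw.
*/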
@Test
public void testDoubleClosing() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// introduce a dummy mapper to get the correct parallelism
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap);
IterativeStream<Integer> iter1 = source.iterate();
iter1.closeWith(iter1.map(NoOpIntMap));
iter1.closeWith(iter1.map(NoOpIntMap));
}
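/**
* Explicitly lowering the parallelism of the feedback stream below that of the iteration
* head is expected to be rejected.
*/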
@Test(expected = UnsupportedOperationException.class)
public void testDifferingParallelism() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// introduce a dummy mapper to get the correct parallelism
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap);
IterativeStream<Integer> iter1 = source.iterate();
iter1.closeWith(iter1.map(NoOpIntMap).setParallelism(DEFAULT_PARALLELISM / 2));
}
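/**
* Same as {@link #testDifferingParallelism()}, but for a connected (co-)iteration.
*/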
@Test(expected = UnsupportedOperationException.class)
public void testCoDifferingParallelism() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// introduce a dummy mapper to get the correct parallelism
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap);
ConnectedIterativeStreams<Integer, Integer> coIter = source.iterate().withFeedbackType(
Integer.class);
coIter.closeWith(coIter.map(NoOpIntCoMap).setParallelism(DEFAULT_PARALLELISM / 2));
}
@Test(expected = UnsupportedOperationException.class)
public void testClosingFromOutOfLoop() throws Exception {
// this test verifies that we cannot close an iteration with a DataStream that does not
// have the iteration among its predecessors
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// introduce a dummy mapper to get the correct parallelism
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap);
IterativeStream<Integer> iter1 = source.iterate();
IterativeStream<Integer> iter2 = source.iterate();
iter2.closeWith(iter1.map(NoOpIntMap));
}
@Test(expected = UnsupportedOperationException.class)
public void testCoIterClosingFromOutOfLoop() throws Exception {
// this test verifies that we cannot close an iteration with a DataStream that does not
// have the iteration among its predecessors
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// introduce a dummy mapper to get the correct parallelism
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap);
IterativeStream<Integer> iter1 = source.iterate();
ConnectedIterativeStreams<Integer, Integer> coIter = source.iterate().withFeedbackType(
Integer.class);
coIter.closeWith(iter1.map(NoOpIntMap));
}
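/**
* Executing a program whose iteration has never been closed is expected to fail.
*/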
@Test(expected = IllegalStateException.class)
public void testExecutionWithEmptyIteration() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap);
IterativeStream<Integer> iter1 = source.iterate();
iter1.map(NoOpIntMap).print();
env.execute();
}
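/**
* {@link IterativeStream#withFeedbackType} should create a new iteration rather than
* mutating the original one, so both iterations can be closed and used independently.
*/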
@Test
public void testImmutabilityWithCoiteration() {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Integer> source = env.fromElements(1, 10).map(NoOpIntMap); // for rebalance
IterativeStream<Integer> iter1 = source.iterate();
// Calling withFeedbackType should create a new iteration
ConnectedIterativeStreams<Integer, String> iter2 = iter1.withFeedbackType(String.class);
iter1.closeWith(iter1.map(NoOpIntMap)).print();
iter2.closeWith(iter2.map(NoOpCoMap)).print();
StreamGraph graph = env.getStreamGraph();
assertEquals(2, graph.getIterationSourceSinkPairs().size());
for (Tuple2<StreamNode, StreamNode> sourceSinkPair: graph.getIterationSourceSinkPairs()) {
assertEquals(sourceSinkPair.f0.getOutEdges().get(0).getTargetVertex(), sourceSinkPair.f1.getInEdges().get(0).getSourceVertex());
}
}
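/**
* Builds an iteration with multiple heads and tails and checks the resulting stream graph:
* the number of iteration source/sink edges, the partitioners chosen for those edges, and
* the co-location of the generated iteration source and sink vertices.
*/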
@Test
public void testMultipleHeadsTailsSimple() {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Integer> source1 = env.fromElements(1, 2, 3, 4, 5)
.shuffle()
.map(NoOpIntMap).name("ParallelizeMapShuffle");
DataStream<Integer> source2 = env.fromElements(1, 2, 3, 4, 5)
.map(NoOpIntMap).name("ParallelizeMapRebalance");
IterativeStream<Integer> iter1 = source1.union(source2).iterate();
DataStream<Integer> head1 = iter1.map(NoOpIntMap).name("IterRebalanceMap").setParallelism(DEFAULT_PARALLELISM / 2);
DataStream<Integer> head2 = iter1.map(NoOpIntMap).name("IterForwardMap");
DataStreamSink<Integer> head3 = iter1.map(NoOpIntMap).setParallelism(DEFAULT_PARALLELISM / 2).addSink(new ReceiveCheckNoOpSink<Integer>());
DataStreamSink<Integer> head4 = iter1.map(NoOpIntMap).addSink(new ReceiveCheckNoOpSink<Integer>());
SplitStream<Integer> source3 = env.fromElements(1, 2, 3, 4, 5)
.map(NoOpIntMap).name("EvenOddSourceMap")
.split(new EvenOddOutputSelector());
iter1.closeWith(source3.select("even").union(
head1.rebalance().map(NoOpIntMap).broadcast(), head2.shuffle()));
StreamGraph graph = env.getStreamGraph();
JobGraph jg = graph.getJobGraph();
assertEquals(1, graph.getIterationSourceSinkPairs().size());
Tuple2<StreamNode, StreamNode> sourceSinkPair = graph.getIterationSourceSinkPairs().iterator().next();
StreamNode itSource = sourceSinkPair.f0;
StreamNode itSink = sourceSinkPair.f1;
assertEquals(4, itSource.getOutEdges().size());
assertEquals(3, itSink.getInEdges().size());
assertEquals(itSource.getParallelism(), itSink.getParallelism());
for (StreamEdge edge : itSource.getOutEdges()) {
if (edge.getTargetVertex().getOperatorName().equals("IterRebalanceMap")) {
assertTrue(edge.getPartitioner() instanceof RebalancePartitioner);
} else if (edge.getTargetVertex().getOperatorName().equals("IterForwardMap")) {
assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
}
}
for (StreamEdge edge : itSink.getInEdges()) {
if (graph.getStreamNode(edge.getSourceId()).getOperatorName().equals("ParallelizeMapShuffle")) {
assertTrue(edge.getPartitioner() instanceof ShufflePartitioner);
}
if (graph.getStreamNode(edge.getSourceId()).getOperatorName().equals("ParallelizeMapForward")) {
assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
}
if (graph.getStreamNode(edge.getSourceId()).getOperatorName().equals("EvenOddSourceMap")) {
assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
assertTrue(edge.getSelectedNames().contains("even"));
}
}
// Test co-location
JobVertex itSource1 = null;
JobVertex itSink1 = null;
for (JobVertex vertex : jg.getVertices()) {
if (vertex.getName().contains("IterationSource")) {
itSource1 = vertex;
} else if (vertex.getName().contains("IterationSink")) {
itSink1 = vertex;
}
}
assertTrue(itSource1.getCoLocationGroup() != null);
assertEquals(itSource1.getCoLocationGroup(), itSink1.getCoLocationGroup());
}
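/**
* Variant of the multiple-heads/tails test above in which explicit partitioning (broadcast,
* shuffle) is applied on the feedback side of the iteration.
*/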
@Test
public void testMultipleHeadsTailsWithTailPartitioning() {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Integer> source1 = env.fromElements(1, 2, 3, 4, 5)
.shuffle()
.map(NoOpIntMap);
DataStream<Integer> source2 = env.fromElements(1, 2, 3, 4, 5)
.map(NoOpIntMap);
IterativeStream<Integer> iter1 = source1.union(source2).iterate();
DataStream<Integer> head1 = iter1.map(NoOpIntMap).name("map1");
DataStream<Integer> head2 = iter1.map(NoOpIntMap)
.setParallelism(DEFAULT_PARALLELISM / 2)
.name("shuffle").rebalance();
DataStreamSink<Integer> head3 = iter1.map(NoOpIntMap).setParallelism(DEFAULT_PARALLELISM / 2)
.addSink(new ReceiveCheckNoOpSink<Integer>());
DataStreamSink<Integer> head4 = iter1.map(NoOpIntMap).addSink(new ReceiveCheckNoOpSink<Integer>());
SplitStream<Integer> source3 = env.fromElements(1, 2, 3, 4, 5)
.map(NoOpIntMap)
.name("split")
.split(new EvenOddOutputSelector());
iter1.closeWith(
source3.select("even").union(
head1.map(NoOpIntMap).name("bc").broadcast(),
head2.map(NoOpIntMap).shuffle()));
StreamGraph graph = env.getStreamGraph();
JobGraph jg = graph.getJobGraph();
assertEquals(1, graph.getIterationSourceSinkPairs().size());
Tuple2<StreamNode, StreamNode> sourceSinkPair = graph.getIterationSourceSinkPairs().iterator().next();
StreamNode itSource = sourceSinkPair.f0;
StreamNode itSink = sourceSinkPair.f1;
assertEquals(4, itSource.getOutEdges().size());
assertEquals(3, itSink.getInEdges().size());
assertEquals(itSource.getParallelism(), itSink.getParallelism());
for (StreamEdge edge : itSource.getOutEdges()) {
if (edge.getTargetVertex().getOperatorName().equals("map1")) {
assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
assertEquals(4, edge.getTargetVertex().getParallelism());
} else if (edge.getTargetVertex().getOperatorName().equals("shuffle")) {
assertTrue(edge.getPartitioner() instanceof RebalancePartitioner);
assertEquals(2, edge.getTargetVertex().getParallelism());
}
}
for (StreamEdge edge : itSink.getInEdges()) {
String tailName = edge.getSourceVertex().getOperatorName();
if (tailName.equals("split")) {
assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
assertTrue(edge.getSelectedNames().contains("even"));
} else if (tailName.equals("bc")) {
assertTrue(edge.getPartitioner() instanceof BroadcastPartitioner);
} else if (tailName.equals("shuffle")) {
assertTrue(edge.getPartitioner() instanceof ShufflePartitioner);
}
}
// Test co-location
JobVertex itSource1 = null;
JobVertex itSink1 = null;
for (JobVertex vertex : jg.getVertices()) {
if (vertex.getName().contains("IterationSource")) {
itSource1 = vertex;
} else if (vertex.getName().contains("IterationSink")) {
itSink1 = vertex;
}
}
assertTrue(itSource1.getCoLocationGroup() != null);
assertTrue(itSink1.getCoLocationGroup() != null);
assertEquals(itSource1.getCoLocationGroup(), itSink1.getCoLocationGroup());
}
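/**
* Runs a simple iteration and verifies that every parallel subtask eventually received a
* feedback record. The run is retried with a doubled iteration timeout on failure to reduce
* flakiness.
*/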
@SuppressWarnings("rawtypes")
@Test
public void testSimpleIteration() throws Exception {
int numRetries = 5;
int timeoutScale = 1;
for (int numRetry = 0; numRetry < numRetries; numRetry++) {
try {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
iterated = new boolean[DEFAULT_PARALLELISM];
DataStream<Boolean> source = env.fromCollection(Collections.nCopies(DEFAULT_PARALLELISM * 2, false))
.map(NoOpBoolMap).name("ParallelizeMap");
IterativeStream<Boolean> iteration = source.iterate(3000 * timeoutScale);
DataStream<Boolean> increment = iteration.flatMap(new IterationHead()).map(NoOpBoolMap);
iteration.map(NoOpBoolMap).addSink(new ReceiveCheckNoOpSink());
iteration.closeWith(increment).addSink(new ReceiveCheckNoOpSink());
env.execute();
for (boolean iter : iterated) {
assertTrue(iter);
}
break; // success
} catch (Throwable t) {
LOG.info("Run " + (numRetry + 1) + "/" + numRetries + " failed", t);
if (numRetry >= numRetries - 1) {
throw t;
} else {
timeoutScale *= 2;
}
}
}
}
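/**
* Runs a connected iteration whose feedback type (String) differs from the input type
* (Integer) and checks the values collected by {@link TestSink}. The run is retried with a
* doubled iteration timeout on failure.
*/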
@Test
public void testCoIteration() throws Exception {
int numRetries = 5;
int timeoutScale = 1;
for (int numRetry = 0; numRetry < numRetries; numRetry++) {
try {
TestSink.collected = new ArrayList<>();
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(2);
DataStream<String> otherSource = env.fromElements("1000", "2000")
.map(NoOpStrMap).name("ParallelizeMap");
ConnectedIterativeStreams<Integer, String> coIt = env.fromElements(0, 0)
.map(NoOpIntMap).name("ParallelizeMap")
.iterate(2000 * timeoutScale)
.withFeedbackType("String");
try {
coIt.keyBy(1, 2);
fail();
} catch (InvalidProgramException e) {
// this is expected
}
DataStream<String> head = coIt
.flatMap(new RichCoFlatMapFunction<Integer, String, String>() {
private static final long serialVersionUID = 1L;
boolean seenFromSource = false;
@Override
public void flatMap1(Integer value, Collector<String> out) throws Exception {
out.collect(((Integer) (value + 1)).toString());
}
@Override
public void flatMap2(String value, Collector<String> out) throws Exception {
Integer intVal = Integer.valueOf(value);
if (intVal < 2) {
out.collect(((Integer) (intVal + 1)).toString());
}
if (intVal == 1000 || intVal == 2000) {
seenFromSource = true;
}
}
@Override
public void close() {
assertTrue(seenFromSource);
}
});
coIt.map(new CoMapFunction<Integer, String, String>() {
@Override
public String map1(Integer value) throws Exception {
return value.toString();
}
@Override
public String map2(String value) throws Exception {
return value;
}
}).addSink(new ReceiveCheckNoOpSink<String>());
coIt.closeWith(head.broadcast().union(otherSource));
head.addSink(new TestSink()).setParallelism(1);
assertEquals(1, env.getStreamGraph().getIterationSourceSinkPairs().size());
env.execute();
Collections.sort(TestSink.collected);
assertEquals(Arrays.asList("1", "1", "2", "2", "2", "2"), TestSink.collected);
break; // success
} catch (Throwable t) {
LOG.info("Run " + (numRetry + 1) + "/" + numRetries + " failed", t);
if (numRetry >= numRetries - 1) {
throw t;
} else {
timeoutScale *= 2;
}
}
}
}
/**
* This test relies on the hash function used by {@link DataStream#keyBy}, which is
* assumed to be {@link MathUtils#murmurHash}.
*
* <p>For the test to pass, all FlatMappers must see at least two records in the iteration,
* which can only be achieved if the hashed values of the input keys form a complete
* residue system modulo the parallelism. Given that the test is designed for 3 parallel
* FlatMapper instances, keys chosen from the range [1, 3] are a suitable choice.
*/
@Test
public void testGroupByFeedback() throws Exception {
int numRetries = 5;
int timeoutScale = 1;
for (int numRetry = 0; numRetry < numRetries; numRetry++) {
try {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(DEFAULT_PARALLELISM - 1);
env.getConfig().setMaxParallelism(env.getParallelism());
KeySelector<Integer, Integer> key = new KeySelector<Integer, Integer>() {
@Override
public Integer getKey(Integer value) throws Exception {
return value % 3;
}
};
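// The elements 1, 2 and 3 produce the residues 1, 2 and 0 under value % 3. Assuming
// murmurHash assigns these three residues to three distinct subtasks (the "complete
// residue system" assumption documented above), every FlatMapper sees its initial record
// plus at least one decremented feedback record, which satisfies the close() assertion.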
DataStream<Integer> source = env.fromElements(1, 2, 3)
.map(NoOpIntMap).name("ParallelizeMap");
IterativeStream<Integer> it = source.keyBy(key).iterate(3000 * timeoutScale);
DataStream<Integer> head = it.flatMap(new RichFlatMapFunction<Integer, Integer>() {
int received = 0;
int key = -1;
@Override
public void flatMap(Integer value, Collector<Integer> out) throws Exception {
received++;
if (key == -1) {
key = MathUtils.murmurHash(value % 3) % 3;
} else {
assertEquals(key, MathUtils.murmurHash(value % 3) % 3);
}
if (value > 0) {
out.collect(value - 1);
}
}
@Override
public void close() {
assertTrue(received > 1);
}
});
it.closeWith(head.keyBy(key).union(head.map(NoOpIntMap).keyBy(key))).addSink(new ReceiveCheckNoOpSink<Integer>());
env.execute();
break; // success
} catch (Throwable t) {
LOG.info("Run " + (numRetry + 1) + "/" + numRetries + " failed", t);
if (numRetry >= numRetries - 1) {
throw t;
} else {
timeoutScale *= 2;
}
}
}
}
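/**
* Enabling checkpointing on a program that contains iterations is expected to be rejected
* unless checkpointing is explicitly forced for iterative jobs.
*/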
@SuppressWarnings("deprecation")
@Test
public void testWithCheckPointing() throws Exception {
int numRetries = 5;
int timeoutScale = 1;
for (int numRetry = 0; numRetry < numRetries; numRetry++) {
try {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.enableCheckpointing();
DataStream<Boolean> source = env.fromCollection(Collections.nCopies(DEFAULT_PARALLELISM * 2, false))
.map(NoOpBoolMap).name("ParallelizeMap");
IterativeStream<Boolean> iteration = source.iterate(3000 * timeoutScale);
iteration.closeWith(iteration.flatMap(new IterationHead())).addSink(new ReceiveCheckNoOpSink<Boolean>());
try {
env.execute();
// this statement should never be reached
fail();
} catch (UnsupportedOperationException e) {
// expected behaviour
}
// Test force checkpointing
try {
env.enableCheckpointing(1, CheckpointingMode.EXACTLY_ONCE, false);
env.execute();
// this statement should never be reached
fail();
} catch (UnsupportedOperationException e) {
// expected behaviour
}
env.enableCheckpointing(1, CheckpointingMode.EXACTLY_ONCE, true);
env.getStreamGraph().getJobGraph();
break; // success
} catch (Throwable t) {
LOG.info("Run " + (numRetry + 1) + "/" + numRetries + " failed", t);
if (numRetry >= numRetries - 1) {
throw t;
} else {
timeoutScale *= 2;
}
}
}
}
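/**
* Head of the iteration tests: initial records ({@code false}) are sent into the feedback
* loop as {@code true}, and records arriving as feedback mark the {@link #iterated} flag
* of their subtask.
*/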
public static final class IterationHead extends RichFlatMapFunction<Boolean, Boolean> {
public void flatMap(Boolean value, Collector<Boolean> out) throws Exception {
int indx = getRuntimeContext().getIndexOfThisSubtask();
if (value) {
iterated[indx] = true;
} else {
out.collect(true);
}
}
}
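// Identity ("no-op") helper functions shared by the tests above. They forward their input
// unchanged and are only used to introduce additional operators, e.g. to reach the desired
// parallelism or to act as iteration heads and tails.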
public static CoMapFunction<Integer, String, String> NoOpCoMap = new CoMapFunction<Integer, String, String>() {
public String map1(Integer value) throws Exception {
return value.toString();
}
public String map2(String value) throws Exception {
return value;
}
};
public static MapFunction<Integer, Integer> NoOpIntMap = new NoOpIntMap();
public static MapFunction<String, String> NoOpStrMap = new MapFunction<String, String>() {
public String map(String value) throws Exception {
return value;
}
};
public static CoMapFunction<Integer, Integer, Integer> NoOpIntCoMap = new CoMapFunction<Integer, Integer, Integer>() {
public Integer map1(Integer value) throws Exception {
return value;
}
public Integer map2(Integer value) throws Exception {
return value;
}
};
public static MapFunction<Boolean, Boolean> NoOpBoolMap = new MapFunction<Boolean, Boolean>() {
public Boolean map(Boolean value) throws Exception {
return value;
}
};
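/**
* Sink that collects every received value into a static list so that
* {@link #testCoIteration()} can assert on the results after execution.
*/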
public static class TestSink implements SinkFunction<String> {
private static final long serialVersionUID = 1L;
public static List<String> collected = new ArrayList<String>();
@Override
public void invoke(String value) throws Exception {
collected.add(value);
}
}
}