/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.api;

import java.lang.reflect.Method;
import java.util.List;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.ResourceSpec;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.api.collector.selector.OutputSelector;
import org.apache.flink.streaming.api.datastream.ConnectedStreams;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.datastream.SplitStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction;
import org.apache.flink.streaming.api.functions.co.CoMapFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.ProcessOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.windowing.assigners.GlobalWindows;
import org.apache.flink.streaming.api.windowing.triggers.CountTrigger;
import org.apache.flink.streaming.api.windowing.triggers.PurgingTrigger;
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.CustomPartitionerWrapper;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.util.Collector;

import org.hamcrest.core.StringStartsWith;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests for {@link DataStream}.
 */
@SuppressWarnings("serial")
public class DataStreamTest {

	/**
	 * Tests union functionality. This ensures that self-unions and unions of streams
	 * with differing parallelism work.
	 */
	@Test
	public void testUnion() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setParallelism(4);

		DataStream<Long> input1 = env.generateSequence(0, 0)
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			});

		DataStream<Long> selfUnion = input1.union(input1).map(new MapFunction<Long, Long>() {
			@Override
			public Long map(Long value) throws Exception {
				return null;
			}
		});

		DataStream<Long> input6 = env.generateSequence(0, 0)
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			});

		DataStream<Long> selfUnionDifferentPartition = input6.broadcast().union(input6).map(new MapFunction<Long, Long>() {
			@Override
			public Long map(Long value) throws Exception {
				return null;
			}
		});

		DataStream<Long> input2 = env.generateSequence(0, 0)
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			}).setParallelism(4);

		DataStream<Long> input3 = env.generateSequence(0, 0)
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			}).setParallelism(2);

		DataStream<Long> unionDifferingParallelism = input2.union(input3).map(new MapFunction<Long, Long>() {
			@Override
			public Long map(Long value) throws Exception {
				return null;
			}
		}).setParallelism(4);

		DataStream<Long> input4 = env.generateSequence(0, 0)
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			}).setParallelism(2);

		DataStream<Long> input5 = env.generateSequence(0, 0)
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			}).setParallelism(4);

		DataStream<Long> unionDifferingPartitioning = input4.broadcast().union(input5).map(new MapFunction<Long, Long>() {
			@Override
			public Long map(Long value) throws Exception {
				return null;
			}
		}).setParallelism(4);

		StreamGraph streamGraph = env.getStreamGraph();

		// verify self union
		assertEquals(2, streamGraph.getStreamNode(selfUnion.getId()).getInEdges().size());
		for (StreamEdge edge : streamGraph.getStreamNode(selfUnion.getId()).getInEdges()) {
			assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
		}

		// verify self union with different partitioners
		assertEquals(2, streamGraph.getStreamNode(selfUnionDifferentPartition.getId()).getInEdges().size());
		boolean hasForward = false;
		boolean hasBroadcast = false;
		for (StreamEdge edge : streamGraph.getStreamNode(selfUnionDifferentPartition.getId()).getInEdges()) {
			if (edge.getPartitioner() instanceof ForwardPartitioner) {
				hasForward = true;
			}
			if (edge.getPartitioner() instanceof BroadcastPartitioner) {
				hasBroadcast = true;
			}
		}
		assertTrue(hasForward && hasBroadcast);

		// verify union of streams with differing parallelism
		assertEquals(2, streamGraph.getStreamNode(unionDifferingParallelism.getId()).getInEdges().size());
		for (StreamEdge edge : streamGraph.getStreamNode(unionDifferingParallelism.getId()).getInEdges()) {
			if (edge.getSourceId() == input2.getId()) {
				assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
			} else if (edge.getSourceId() == input3.getId()) {
				assertTrue(edge.getPartitioner() instanceof RebalancePartitioner);
			} else {
				fail("Wrong input edge.");
			}
		}

		// verify union of streams with differing partitionings
		assertEquals(2, streamGraph.getStreamNode(unionDifferingPartitioning.getId()).getInEdges().size());
		for (StreamEdge edge : streamGraph.getStreamNode(unionDifferingPartitioning.getId()).getInEdges()) {
			if (edge.getSourceId() == input4.getId()) {
				assertTrue(edge.getPartitioner() instanceof BroadcastPartitioner);
			} else if (edge.getSourceId() == input5.getId()) {
				assertTrue(edge.getPartitioner() instanceof ForwardPartitioner);
			} else {
				fail("Wrong input edge.");
			}
		}
	}

	/**
	 * Tests {@link SingleOutputStreamOperator#name(String)} functionality.
	 */
	@Test
	public void testNaming() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Long> dataStream1 = env.generateSequence(0, 0).name("testSource1")
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			}).name("testMap");

		DataStream<Long> dataStream2 = env.generateSequence(0, 0).name("testSource2")
			.map(new MapFunction<Long, Long>() {
				@Override
				public Long map(Long value) throws Exception {
					return null;
				}
			}).name("testMap");

		dataStream1.connect(dataStream2)
			.flatMap(new CoFlatMapFunction<Long, Long, Long>() {
				@Override
				public void flatMap1(Long value, Collector<Long> out) throws Exception {}

				@Override
				public void flatMap2(Long value, Collector<Long> out) throws Exception {}
			}).name("testCoFlatMap")
			.windowAll(GlobalWindows.create())
			.trigger(PurgingTrigger.of(CountTrigger.of(10)))
			.fold(0L, new FoldFunction<Long, Long>() {
				private static final long serialVersionUID = 1L;

				@Override
				public Long fold(Long accumulator, Long value) throws Exception {
					return null;
				}
			})
			.name("testWindowFold")
			.print();

		// test functionality through the operator names in the execution plan
		String plan = env.getExecutionPlan();

		assertTrue(plan.contains("testSource1"));
		assertTrue(plan.contains("testSource2"));
		assertTrue(plan.contains("testMap"));
		assertTrue(plan.contains("testCoFlatMap"));
		assertTrue(plan.contains("testWindowFold"));
	}

	/**
	 * Tests that {@link DataStream#keyBy} and {@link DataStream#partitionCustom(Partitioner, int)} result in
	 * different and correct topologies. Does the same for the {@link ConnectedStreams}.
	 */
	@Test
	@SuppressWarnings("unchecked")
	public void testPartitioning() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L));
		DataStream<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L));
		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connected = src1.connect(src2);

		// testing DataStream grouping
		DataStream<Tuple2<Long, Long>> group1 = src1.keyBy(0);
		DataStream<Tuple2<Long, Long>> group2 = src1.keyBy(1, 0);
		DataStream<Tuple2<Long, Long>> group3 = src1.keyBy("f0");
		DataStream<Tuple2<Long, Long>> group4 = src1.keyBy(new FirstSelector());

		int id1 = createDownStreamId(group1);
		int id2 = createDownStreamId(group2);
		int id3 = createDownStreamId(group3);
		int id4 = createDownStreamId(group4);

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), id1)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), id2)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), id3)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), id4)));

		assertTrue(isKeyed(group1));
		assertTrue(isKeyed(group2));
		assertTrue(isKeyed(group3));
		assertTrue(isKeyed(group4));

		// testing DataStream partitioning
		DataStream<Tuple2<Long, Long>> partition1 = src1.keyBy(0);
		DataStream<Tuple2<Long, Long>> partition2 = src1.keyBy(1, 0);
		DataStream<Tuple2<Long, Long>> partition3 = src1.keyBy("f0");
		DataStream<Tuple2<Long, Long>> partition4 = src1.keyBy(new FirstSelector());

		int pid1 = createDownStreamId(partition1);
		int pid2 = createDownStreamId(partition2);
		int pid3 = createDownStreamId(partition3);
		int pid4 = createDownStreamId(partition4);

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), pid1)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), pid2)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), pid3)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), pid4)));

		assertTrue(isKeyed(partition1));
		assertTrue(isKeyed(partition2));
		assertTrue(isKeyed(partition3));
		assertTrue(isKeyed(partition4));

		// testing DataStream custom partitioning
		Partitioner<Long> longPartitioner = new Partitioner<Long>() {
			@Override
			public int partition(Long key, int numPartitions) {
				return 100;
			}
		};

		DataStream<Tuple2<Long, Long>> customPartition1 = src1.partitionCustom(longPartitioner, 0);
		DataStream<Tuple2<Long, Long>> customPartition3 = src1.partitionCustom(longPartitioner, "f0");
		DataStream<Tuple2<Long, Long>> customPartition4 = src1.partitionCustom(longPartitioner, new FirstSelector());

		int cid1 = createDownStreamId(customPartition1);
		int cid2 = createDownStreamId(customPartition3);
		int cid3 = createDownStreamId(customPartition4);

		assertTrue(isCustomPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), cid1)));
		assertTrue(isCustomPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), cid2)));
		assertTrue(isCustomPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), cid3)));

		assertFalse(isKeyed(customPartition1));
		assertFalse(isKeyed(customPartition3));
		assertFalse(isKeyed(customPartition4));

		// testing ConnectedStreams grouping
		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedGroup1 = connected.keyBy(0, 0);
		Integer downStreamId1 = createDownStreamId(connectedGroup1);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedGroup2 = connected.keyBy(new int[]{0}, new int[]{0});
		Integer downStreamId2 = createDownStreamId(connectedGroup2);
		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedGroup3 = connected.keyBy("f0", "f0");
		Integer downStreamId3 = createDownStreamId(connectedGroup3);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedGroup4 = connected.keyBy(new String[]{"f0"}, new String[]{"f0"});
		Integer downStreamId4 = createDownStreamId(connectedGroup4);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedGroup5 = connected.keyBy(new FirstSelector(), new FirstSelector());
		Integer downStreamId5 = createDownStreamId(connectedGroup5);

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), downStreamId1)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), downStreamId1)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), downStreamId2)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), downStreamId2)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), downStreamId3)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), downStreamId3)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), downStreamId4)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), downStreamId4)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), downStreamId5)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), downStreamId5)));

		assertTrue(isKeyed(connectedGroup1));
		assertTrue(isKeyed(connectedGroup2));
		assertTrue(isKeyed(connectedGroup3));
		assertTrue(isKeyed(connectedGroup4));
		assertTrue(isKeyed(connectedGroup5));

		// testing ConnectedStreams partitioning
		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedPartition1 = connected.keyBy(0, 0);
		Integer connectDownStreamId1 = createDownStreamId(connectedPartition1);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedPartition2 = connected.keyBy(new int[]{0}, new int[]{0});
		Integer connectDownStreamId2 = createDownStreamId(connectedPartition2);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedPartition3 = connected.keyBy("f0", "f0");
		Integer connectDownStreamId3 = createDownStreamId(connectedPartition3);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedPartition4 = connected.keyBy(new String[]{"f0"}, new String[]{"f0"});
		Integer connectDownStreamId4 = createDownStreamId(connectedPartition4);

		ConnectedStreams<Tuple2<Long, Long>, Tuple2<Long, Long>> connectedPartition5 = connected.keyBy(new FirstSelector(), new FirstSelector());
		Integer connectDownStreamId5 = createDownStreamId(connectedPartition5);

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), connectDownStreamId1)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), connectDownStreamId1)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), connectDownStreamId2)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), connectDownStreamId2)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), connectDownStreamId3)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), connectDownStreamId3)));

		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), connectDownStreamId4)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), connectDownStreamId4)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src1.getId(), connectDownStreamId5)));
		assertTrue(isPartitioned(env.getStreamGraph().getStreamEdges(src2.getId(), connectDownStreamId5)));

		assertTrue(isKeyed(connectedPartition1));
		assertTrue(isKeyed(connectedPartition2));
		assertTrue(isKeyed(connectedPartition3));
		assertTrue(isKeyed(connectedPartition4));
		assertTrue(isKeyed(connectedPartition5));
	}

	/**
	 * Tests whether parallelism gets set.
	 */
	@Test
	public void testParallelism() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStreamSource<Tuple2<Long, Long>> src = env.fromElements(new Tuple2<>(0L, 0L));
		env.setParallelism(10);

		SingleOutputStreamOperator<Long> map = src.map(new MapFunction<Tuple2<Long, Long>, Long>() {
			@Override
			public Long map(Tuple2<Long, Long> value) throws Exception {
				return null;
			}
		}).name("MyMap");

		DataStream<Long> windowed = map
			.windowAll(GlobalWindows.create())
			.trigger(PurgingTrigger.of(CountTrigger.of(10)))
			.fold(0L, new FoldFunction<Long, Long>() {
				@Override
				public Long fold(Long accumulator, Long value) throws Exception {
					return null;
				}
			});

		windowed.addSink(new DiscardingSink<Long>());

		DataStreamSink<Long> sink = map.addSink(new SinkFunction<Long>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void invoke(Long value) throws Exception {
			}
		});

		assertEquals(1, env.getStreamGraph().getStreamNode(src.getId()).getParallelism());
		assertEquals(10, env.getStreamGraph().getStreamNode(map.getId()).getParallelism());
		assertEquals(1, env.getStreamGraph().getStreamNode(windowed.getId()).getParallelism());
		assertEquals(10, env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getParallelism());

		env.setParallelism(7);

		// Some parts, such as windowing, rely on the fact that previous operators have a
		// parallelism set when instantiating the Discretizer. This would break if we
		// dynamically changed the parallelism of operations when changing the setting
		// on the ExecutionEnvironment.
		assertEquals(1, env.getStreamGraph().getStreamNode(src.getId()).getParallelism());
		assertEquals(10, env.getStreamGraph().getStreamNode(map.getId()).getParallelism());
		assertEquals(1, env.getStreamGraph().getStreamNode(windowed.getId()).getParallelism());
		assertEquals(10, env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getParallelism());

		try {
			src.setParallelism(3);
			fail("The parallelism of a non-parallel source must not be changeable.");
		} catch (IllegalArgumentException success) {
			// expected
		}

		DataStreamSource<Long> parallelSource = env.generateSequence(0, 0);
		parallelSource.addSink(new DiscardingSink<Long>());
		assertEquals(7, env.getStreamGraph().getStreamNode(parallelSource.getId()).getParallelism());

		parallelSource.setParallelism(3);
		assertEquals(3, env.getStreamGraph().getStreamNode(parallelSource.getId()).getParallelism());

		map.setParallelism(2);
		assertEquals(2, env.getStreamGraph().getStreamNode(map.getId()).getParallelism());

		sink.setParallelism(4);
		assertEquals(4, env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getParallelism());
	}

	/**
	 * Tests whether resources get set.
	 */
	@Test
	public void testResources() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		ResourceSpec minResource1 = new ResourceSpec(1.0, 100);
		ResourceSpec preferredResource1 = new ResourceSpec(2.0, 200);

		ResourceSpec minResource2 = new ResourceSpec(1.0, 200);
		ResourceSpec preferredResource2 = new ResourceSpec(2.0, 300);

		ResourceSpec minResource3 = new ResourceSpec(1.0, 300);
		ResourceSpec preferredResource3 = new ResourceSpec(2.0, 400);

		ResourceSpec minResource4 = new ResourceSpec(1.0, 400);
		ResourceSpec preferredResource4 = new ResourceSpec(2.0, 500);

		ResourceSpec minResource5 = new ResourceSpec(1.0, 500);
		ResourceSpec preferredResource5 = new ResourceSpec(2.0, 600);

		ResourceSpec minResource6 = new ResourceSpec(1.0, 600);
		ResourceSpec preferredResource6 = new ResourceSpec(2.0, 700);

		ResourceSpec minResource7 = new ResourceSpec(1.0, 700);
		ResourceSpec preferredResource7 = new ResourceSpec(2.0, 800);

		// setResources is not publicly exposed, so it is invoked via reflection here
		Method opMethod = SingleOutputStreamOperator.class.getDeclaredMethod("setResources", ResourceSpec.class, ResourceSpec.class);
		opMethod.setAccessible(true);

		Method sinkMethod = DataStreamSink.class.getDeclaredMethod("setResources", ResourceSpec.class, ResourceSpec.class);
		sinkMethod.setAccessible(true);

		DataStream<Long> source1 = env.generateSequence(0, 0);
		opMethod.invoke(source1, minResource1, preferredResource1);

		DataStream<Long> map1 = source1.map(new MapFunction<Long, Long>() {
			@Override
			public Long map(Long value) throws Exception {
				return null;
			}
		});
		opMethod.invoke(map1, minResource2, preferredResource2);

		DataStream<Long> source2 = env.generateSequence(0, 0);
		opMethod.invoke(source2, minResource3, preferredResource3);

		DataStream<Long> map2 = source2.map(new MapFunction<Long, Long>() {
			@Override
			public Long map(Long value) throws Exception {
				return null;
			}
		});
		opMethod.invoke(map2, minResource4, preferredResource4);

		DataStream<Long> connected = map1.connect(map2)
			.flatMap(new CoFlatMapFunction<Long, Long, Long>() {
				@Override
				public void flatMap1(Long value, Collector<Long> out) throws Exception {
				}

				@Override
				public void flatMap2(Long value, Collector<Long> out) throws Exception {
				}
			});
		opMethod.invoke(connected, minResource5, preferredResource5);

		DataStream<Long> windowed = connected
			.windowAll(GlobalWindows.create())
			.trigger(PurgingTrigger.of(CountTrigger.of(10)))
			.fold(0L, new FoldFunction<Long, Long>() {
				private static final long serialVersionUID = 1L;

				@Override
				public Long fold(Long accumulator, Long value) throws Exception {
					return null;
				}
			});
		opMethod.invoke(windowed, minResource6, preferredResource6);

		DataStreamSink<Long> sink = windowed.print();
		sinkMethod.invoke(sink, minResource7, preferredResource7);

		assertEquals(minResource1, env.getStreamGraph().getStreamNode(source1.getId()).getMinResources());
		assertEquals(preferredResource1, env.getStreamGraph().getStreamNode(source1.getId()).getPreferredResources());

		assertEquals(minResource2, env.getStreamGraph().getStreamNode(map1.getId()).getMinResources());
		assertEquals(preferredResource2, env.getStreamGraph().getStreamNode(map1.getId()).getPreferredResources());

		assertEquals(minResource3, env.getStreamGraph().getStreamNode(source2.getId()).getMinResources());
		assertEquals(preferredResource3, env.getStreamGraph().getStreamNode(source2.getId()).getPreferredResources());

		assertEquals(minResource4, env.getStreamGraph().getStreamNode(map2.getId()).getMinResources());
		assertEquals(preferredResource4, env.getStreamGraph().getStreamNode(map2.getId()).getPreferredResources());
		assertEquals(minResource5, env.getStreamGraph().getStreamNode(connected.getId()).getMinResources());
		assertEquals(preferredResource5, env.getStreamGraph().getStreamNode(connected.getId()).getPreferredResources());

		assertEquals(minResource6, env.getStreamGraph().getStreamNode(windowed.getId()).getMinResources());
		assertEquals(preferredResource6, env.getStreamGraph().getStreamNode(windowed.getId()).getPreferredResources());

		assertEquals(minResource7, env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getMinResources());
		assertEquals(preferredResource7, env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getPreferredResources());
	}
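	/**
	 * Tests that the {@link TypeInformation} of a stream matches the output type of the
	 * operation that produced it (source, map, window apply, and window fold).
	 */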
	@Test
	public void testTypeInfo() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Long> src1 = env.generateSequence(0, 0);
		assertEquals(TypeExtractor.getForClass(Long.class), src1.getType());

		DataStream<Tuple2<Integer, String>> map = src1.map(new MapFunction<Long, Tuple2<Integer, String>>() {
			@Override
			public Tuple2<Integer, String> map(Long value) throws Exception {
				return null;
			}
		});
		assertEquals(TypeExtractor.getForObject(new Tuple2<>(0, "")), map.getType());

		DataStream<String> window = map
			.windowAll(GlobalWindows.create())
			.trigger(PurgingTrigger.of(CountTrigger.of(5)))
			.apply(new AllWindowFunction<Tuple2<Integer, String>, String, GlobalWindow>() {
				@Override
				public void apply(GlobalWindow window,
						Iterable<Tuple2<Integer, String>> values,
						Collector<String> out) throws Exception {
				}
			});
		assertEquals(TypeExtractor.getForClass(String.class), window.getType());

		DataStream<CustomPOJO> flatten = window
			.windowAll(GlobalWindows.create())
			.trigger(PurgingTrigger.of(CountTrigger.of(5)))
			.fold(new CustomPOJO(), new FoldFunction<String, CustomPOJO>() {
				private static final long serialVersionUID = 1L;

				@Override
				public CustomPOJO fold(CustomPOJO accumulator, String value) throws Exception {
					return null;
				}
			});
		assertEquals(TypeExtractor.getForClass(CustomPOJO.class), flatten.getType());
	}

	/**
	 * Verify that a {@link KeyedStream#process(ProcessFunction)} call is correctly translated to
	 * an operator.
	 */
	@Test
	public void testKeyedProcessTranslation() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		DataStreamSource<Long> src = env.generateSequence(0, 0);

		ProcessFunction<Long, Integer> processFunction = new ProcessFunction<Long, Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void processElement(Long value, Context ctx, Collector<Integer> out) throws Exception {
			}

			@Override
			public void onTimer(long timestamp, OnTimerContext ctx, Collector<Integer> out) throws Exception {
			}
		};

		DataStream<Integer> processed = src
			.keyBy(new IdentityKeySelector<Long>())
			.process(processFunction);

		processed.addSink(new DiscardingSink<Integer>());

		assertEquals(processFunction, getFunctionForDataStream(processed));
		assertTrue(getOperatorForDataStream(processed) instanceof KeyedProcessOperator);
	}

	/**
	 * Verify that a {@link DataStream#process(ProcessFunction)} call is correctly translated to
	 * an operator.
	 */
	@Test
	public void testProcessTranslation() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		DataStreamSource<Long> src = env.generateSequence(0, 0);

		ProcessFunction<Long, Integer> processFunction = new ProcessFunction<Long, Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void processElement(Long value, Context ctx, Collector<Integer> out) throws Exception {
			}

			@Override
			public void onTimer(long timestamp, OnTimerContext ctx, Collector<Integer> out) throws Exception {
			}
		};

		DataStream<Integer> processed = src
			.process(processFunction);

		processed.addSink(new DiscardingSink<Integer>());

		assertEquals(processFunction, getFunctionForDataStream(processed));
		assertTrue(getOperatorForDataStream(processed) instanceof ProcessOperator);
	}
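	/**
	 * Tests that user functions attached via map, flatMap, filter, split/select, and
	 * connected co-map end up in the expected nodes and edges of the {@link StreamGraph}.
	 */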
	@Test
	public void operatorTest() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStreamSource<Long> src = env.generateSequence(0, 0);

		MapFunction<Long, Integer> mapFunction = new MapFunction<Long, Integer>() {
			@Override
			public Integer map(Long value) throws Exception {
				return null;
			}
		};
		DataStream<Integer> map = src.map(mapFunction);
		map.addSink(new DiscardingSink<Integer>());
		assertEquals(mapFunction, getFunctionForDataStream(map));

		FlatMapFunction<Long, Integer> flatMapFunction = new FlatMapFunction<Long, Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void flatMap(Long value, Collector<Integer> out) throws Exception {
			}
		};
		DataStream<Integer> flatMap = src.flatMap(flatMapFunction);
		flatMap.addSink(new DiscardingSink<Integer>());
		assertEquals(flatMapFunction, getFunctionForDataStream(flatMap));

		FilterFunction<Integer> filterFunction = new FilterFunction<Integer>() {
			@Override
			public boolean filter(Integer value) throws Exception {
				return false;
			}
		};

		DataStream<Integer> unionFilter = map.union(flatMap)
			.filter(filterFunction);
		unionFilter.addSink(new DiscardingSink<Integer>());
		assertEquals(filterFunction, getFunctionForDataStream(unionFilter));

		try {
			env.getStreamGraph().getStreamEdges(map.getId(), unionFilter.getId());
		} catch (RuntimeException e) {
			fail(e.getMessage());
		}

		try {
			env.getStreamGraph().getStreamEdges(flatMap.getId(), unionFilter.getId());
		} catch (RuntimeException e) {
			fail(e.getMessage());
		}

		OutputSelector<Integer> outputSelector = new OutputSelector<Integer>() {
			@Override
			public Iterable<String> select(Integer value) {
				return null;
			}
		};
		SplitStream<Integer> split = unionFilter.split(outputSelector);
		split.select("dummy").addSink(new DiscardingSink<Integer>());
		List<OutputSelector<?>> outputSelectors = env.getStreamGraph().getStreamNode(unionFilter.getId()).getOutputSelectors();
		assertEquals(1, outputSelectors.size());
		assertEquals(outputSelector, outputSelectors.get(0));

		DataStream<Integer> select = split.select("a");
		DataStreamSink<Integer> sink = select.print();

		StreamEdge splitEdge = env.getStreamGraph().getStreamEdges(unionFilter.getId(), sink.getTransformation().getId()).get(0);
		assertEquals("a", splitEdge.getSelectedNames().get(0));

		ConnectedStreams<Integer, Integer> connect = map.connect(flatMap);
		CoMapFunction<Integer, Integer, String> coMapper = new CoMapFunction<Integer, Integer, String>() {
			private static final long serialVersionUID = 1L;

			@Override
			public String map1(Integer value) {
				return null;
			}

			@Override
			public String map2(Integer value) {
				return null;
			}
		};
		DataStream<String> coMap = connect.map(coMapper);
		coMap.addSink(new DiscardingSink<String>());
		assertEquals(coMapper, getFunctionForDataStream(coMap));

		try {
			env.getStreamGraph().getStreamEdges(map.getId(), coMap.getId());
		} catch (RuntimeException e) {
			fail(e.getMessage());
		}

		try {
			env.getStreamGraph().getStreamEdges(flatMap.getId(), coMap.getId());
		} catch (RuntimeException e) {
			fail(e.getMessage());
		}
	}
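	/**
	 * Tests that the key selector and key serializer of a {@link KeyedStream} are forwarded
	 * to the sink node in the {@link StreamGraph}, and that the in-edge of the sink uses the
	 * expected partitioner.
	 */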
	@Test
	public void sinkKeyTest() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStreamSink<Long> sink = env.generateSequence(1, 100).print();
		assertNull(env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getStatePartitioner1());
		assertTrue(env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof ForwardPartitioner);

		KeySelector<Long, Long> key1 = new KeySelector<Long, Long>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Long getKey(Long value) throws Exception {
				return (long) 0;
			}
		};

		DataStreamSink<Long> sink2 = env.generateSequence(1, 100).keyBy(key1).print();

		assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner1());
		assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer());
		assertEquals(key1, env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner1());
		assertTrue(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof KeyGroupStreamPartitioner);

		KeySelector<Long, Long> key2 = new KeySelector<Long, Long>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Long getKey(Long value) throws Exception {
				return (long) 0;
			}
		};

		DataStreamSink<Long> sink3 = env.generateSequence(1, 100).keyBy(key2).print();

		assertNotNull(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1());
		assertEquals(key2, env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1());
		assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof KeyGroupStreamPartitioner);
	}

	/**
	 * Tests that broadcast, shuffle, forward, rebalance, and global each install the
	 * corresponding {@link StreamPartitioner} on the outgoing edge.
	 */
	@Test
	public void testChannelSelectors() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStreamSource<Long> src = env.generateSequence(0, 0);

		DataStream<Long> broadcast = src.broadcast();
		DataStreamSink<Long> broadcastSink = broadcast.print();
		StreamPartitioner<?> broadcastPartitioner =
			env.getStreamGraph().getStreamEdges(src.getId(), broadcastSink.getTransformation().getId()).get(0).getPartitioner();
		assertTrue(broadcastPartitioner instanceof BroadcastPartitioner);

		DataStream<Long> shuffle = src.shuffle();
		DataStreamSink<Long> shuffleSink = shuffle.print();
		StreamPartitioner<?> shufflePartitioner =
			env.getStreamGraph().getStreamEdges(src.getId(), shuffleSink.getTransformation().getId()).get(0).getPartitioner();
		assertTrue(shufflePartitioner instanceof ShufflePartitioner);

		DataStream<Long> forward = src.forward();
		DataStreamSink<Long> forwardSink = forward.print();
		StreamPartitioner<?> forwardPartitioner =
			env.getStreamGraph().getStreamEdges(src.getId(), forwardSink.getTransformation().getId()).get(0).getPartitioner();
		assertTrue(forwardPartitioner instanceof ForwardPartitioner);

		DataStream<Long> rebalance = src.rebalance();
		DataStreamSink<Long> rebalanceSink = rebalance.print();
		StreamPartitioner<?> rebalancePartitioner =
			env.getStreamGraph().getStreamEdges(src.getId(), rebalanceSink.getTransformation().getId()).get(0).getPartitioner();
		assertTrue(rebalancePartitioner instanceof RebalancePartitioner);

		DataStream<Long> global = src.global();
		DataStreamSink<Long> globalSink = global.print();
		StreamPartitioner<?> globalPartitioner =
			env.getStreamGraph().getStreamEdges(src.getId(), globalSink.getTransformation().getId()).get(0).getPartitioner();
		assertTrue(globalPartitioner instanceof GlobalPartitioner);
	}

	/////////////////////////////////////////////////////////////
	// KeyBy testing
	/////////////////////////////////////////////////////////////

	@Rule
	public ExpectedException expectedException = ExpectedException.none();

	@Test
	public void testPrimitiveArrayKeyRejection() {

		KeySelector<Tuple2<Integer[], String>, int[]> keySelector =
				new KeySelector<Tuple2<Integer[], String>, int[]>() {

			@Override
			public int[] getKey(Tuple2<Integer[], String> value) throws Exception {
				int[] ks = new int[value.f0.length];
				for (int i = 0; i < ks.length; i++) {
					ks[i] = value.f0[i];
				}
				return ks;
			}
		};

		testKeyRejection(keySelector, PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO);
	}

	@Test
	public void testBasicArrayKeyRejection() {

		KeySelector<Tuple2<Integer[], String>, Integer[]> keySelector =
				new KeySelector<Tuple2<Integer[], String>, Integer[]>() {

			@Override
			public Integer[] getKey(Tuple2<Integer[], String> value) throws Exception {
				return value.f0;
			}
		};

		testKeyRejection(keySelector, BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO);
	}

	@Test
	public void testObjectArrayKeyRejection() {

		KeySelector<Tuple2<Integer[], String>, Object[]> keySelector =
				new KeySelector<Tuple2<Integer[], String>, Object[]>() {

			@Override
			public Object[] getKey(Tuple2<Integer[], String> value) throws Exception {
				Object[] ks = new Object[value.f0.length];
				for (int i = 0; i < ks.length; i++) {
					ks[i] = new Object();
				}
				return ks;
			}
		};

		ObjectArrayTypeInfo<Object[], Object> keyTypeInfo = ObjectArrayTypeInfo.getInfoFor(
				Object[].class, new GenericTypeInfo<>(Object.class));

		testKeyRejection(keySelector, keyTypeInfo);
	}

	private <K> void testKeyRejection(KeySelector<Tuple2<Integer[], String>, K> keySelector, TypeInformation<K> expectedKeyType) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Tuple2<Integer[], String>> input = env.fromElements(
				new Tuple2<>(new Integer[] {1, 2}, "barfoo"));

		Assert.assertEquals(expectedKeyType, TypeExtractor.getKeySelectorTypes(keySelector, input.getType()));

		// adjust the rule
		expectedException.expect(InvalidProgramException.class);
		expectedException.expectMessage(new StringStartsWith("Type " + expectedKeyType + " cannot be used as key."));

		input.keyBy(keySelector);
	}

	//////////////// Composite Key Tests : POJOs ////////////////

	@Test
	public void testPOJOWithNestedArrayNoHashCodeKeyRejection() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<POJOWithHashCode> input = env.fromElements(
				new POJOWithHashCode(new int[] {1, 2}));

		TypeInformation<?> expectedTypeInfo = new TupleTypeInfo<Tuple1<int[]>>(
				PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO);

		// adjust the rule
		expectedException.expect(InvalidProgramException.class);
		expectedException.expectMessage(new StringStartsWith("Type " + expectedTypeInfo + " cannot be used as key."));

		input.keyBy("id");
	}

	@Test
	public void testPOJOWithNestedArrayAndHashCodeWorkAround() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		DataStream<POJOWithHashCode> input = env.fromElements(
				new POJOWithHashCode(new int[] {1, 2}));

		input.keyBy(new KeySelector<POJOWithHashCode, POJOWithHashCode>() {
			@Override
			public POJOWithHashCode getKey(POJOWithHashCode value) throws Exception {
				return value;
			}
		}).addSink(new SinkFunction<POJOWithHashCode>() {
			@Override
			public void invoke(POJOWithHashCode value) throws Exception {
				Assert.assertArrayEquals(new int[]{1, 2}, value.getId());
			}
		});
	}

	@Test
	public void testPOJOnoHashCodeKeyRejection() {

		KeySelector<POJOWithoutHashCode, POJOWithoutHashCode> keySelector =
				new KeySelector<POJOWithoutHashCode, POJOWithoutHashCode>() {

			@Override
			public POJOWithoutHashCode getKey(POJOWithoutHashCode value) throws Exception {
				return value;
			}
		};

		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<POJOWithoutHashCode> input = env.fromElements(
				new POJOWithoutHashCode(new int[] {1, 2}));

		// adjust the rule
		expectedException.expect(InvalidProgramException.class);

		input.keyBy(keySelector);
	}

	//////////////// Composite Key Tests : Tuples ////////////////

	@Test
	public void testTupleNestedArrayKeyRejection() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Tuple2<Integer[], String>> input = env.fromElements(
				new Tuple2<>(new Integer[] {1, 2}, "test-test"));

		TypeInformation<?> expectedTypeInfo = new TupleTypeInfo<Tuple2<Integer[], String>>(
				BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);

		// adjust the rule
		expectedException.expect(InvalidProgramException.class);
		expectedException.expectMessage(new StringStartsWith("Type " + expectedTypeInfo + " cannot be used as key."));

		input.keyBy(new KeySelector<Tuple2<Integer[], String>, Tuple2<Integer[], String>>() {
			@Override
			public Tuple2<Integer[], String> getKey(Tuple2<Integer[], String> value) throws Exception {
				return value;
			}
		});
	}

	@Test
	public void testPrimitiveKeyAcceptance() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setParallelism(1);
		env.setMaxParallelism(1);

		DataStream<Integer> input = env.fromElements(10000);

		KeyedStream<Integer, Object> keyedStream = input.keyBy(new KeySelector<Integer, Object>() {
			@Override
			public Object getKey(Integer value) throws Exception {
				return value;
			}
		});

		keyedStream.addSink(new SinkFunction<Integer>() {
			@Override
			public void invoke(Integer value) throws Exception {
				Assert.assertEquals(10000L, (long) value);
			}
		});
	}

	public static class POJOWithoutHashCode {

		private int[] id;

		public POJOWithoutHashCode() {}

		public POJOWithoutHashCode(int[] id) {
			this.id = id;
		}

		public int[] getId() {
			return id;
		}

		public void setId(int[] id) {
			this.id = id;
		}
	}

	public static class POJOWithHashCode extends POJOWithoutHashCode {

		public POJOWithHashCode() {
		}

		public POJOWithHashCode(int[] id) {
			super(id);
		}

		@Override
		public int hashCode() {
			int hash = 31;
			for (int i : getId()) {
				hash = 37 * hash + i;
			}
			return hash;
		}
	}

	/////////////////////////////////////////////////////////////
	// Utilities
	/////////////////////////////////////////////////////////////

	private static StreamOperator<?> getOperatorForDataStream(DataStream<?> dataStream) {
		StreamExecutionEnvironment env = dataStream.getExecutionEnvironment();
		StreamGraph streamGraph = env.getStreamGraph();
		return streamGraph.getStreamNode(dataStream.getId()).getOperator();
	}

	private static Function getFunctionForDataStream(DataStream<?> dataStream) {
		AbstractUdfStreamOperator<?, ?> operator =
				(AbstractUdfStreamOperator<?, ?>) getOperatorForDataStream(dataStream);
		return operator.getUserFunction();
	}
	private static Integer createDownStreamId(DataStream<?> dataStream) {
		return dataStream.print().getTransformation().getId();
	}

	private static boolean isKeyed(DataStream<?> dataStream) {
		return dataStream instanceof KeyedStream;
	}

	@SuppressWarnings({"rawtypes", "unchecked"})
	private static Integer createDownStreamId(ConnectedStreams dataStream) {
		SingleOutputStreamOperator<?> coMap = dataStream.map(new CoMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Object>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Object map1(Tuple2<Long, Long> value) {
				return null;
			}

			@Override
			public Object map2(Tuple2<Long, Long> value) {
				return null;
			}
		});
		coMap.addSink(new DiscardingSink());
		return coMap.getId();
	}

	private static boolean isKeyed(ConnectedStreams<?, ?> dataStream) {
		return (dataStream.getFirstInput() instanceof KeyedStream
				&& dataStream.getSecondInput() instanceof KeyedStream);
	}

	private static boolean isPartitioned(List<StreamEdge> edges) {
		boolean result = true;
		for (StreamEdge edge : edges) {
			if (!(edge.getPartitioner() instanceof KeyGroupStreamPartitioner)) {
				result = false;
			}
		}
		return result;
	}

	private static boolean isCustomPartitioned(List<StreamEdge> edges) {
		boolean result = true;
		for (StreamEdge edge : edges) {
			if (!(edge.getPartitioner() instanceof CustomPartitionerWrapper)) {
				result = false;
			}
		}
		return result;
	}

	private static class FirstSelector implements KeySelector<Tuple2<Long, Long>, Long> {
		private static final long serialVersionUID = 1L;

		@Override
		public Long getKey(Tuple2<Long, Long> value) throws Exception {
			return value.f0;
		}
	}

	private static class IdentityKeySelector<T> implements KeySelector<T, T> {
		private static final long serialVersionUID = 1L;

		@Override
		public T getKey(T value) throws Exception {
			return value;
		}
	}

	public static class CustomPOJO {
		private String s;
		private int i;

		public CustomPOJO() {
		}

		public void setS(String s) {
			this.s = s;
		}

		public void setI(int i) {
			this.i = i;
		}

		public String getS() {
			return s;
		}

		public int getI() {
			return i;
		}
	}
}