/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.streaming.api;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.collector.selector.OutputSelector;
import org.apache.flink.streaming.api.datastream.AsyncDataStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SplitStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector;
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
import org.apache.flink.util.MathUtils;
import org.junit.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class StreamingOperatorsITCase extends StreamingMultipleProgramsTestBase {
/**
* Tests the proper functioning of the streaming fold operator. For this purpose, a stream
* of Tuple2<Integer, Integer> is created. The stream is grouped according to the first tuple
* value. Each group is folded where the second tuple value is summed up.
*
* This test relies on the hash function used by the {@link DataStream#keyBy}, which is
* assumed to be {@link MathUtils#murmurHash}.
*/
@Test
public void testGroupedFoldOperation() throws Exception {
int numElements = 10;
final int numKeys = 2;
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Tuple2<Integer, Integer>> sourceStream = env.addSource(new TupleSource(numElements, numKeys));
SplitStream<Tuple2<Integer, Integer>> splittedResult = sourceStream
.keyBy(0)
.fold(0, new FoldFunction<Tuple2<Integer, Integer>, Integer>() {
private static final long serialVersionUID = 4875723041825726082L;
@Override
public Integer fold(Integer accumulator, Tuple2<Integer, Integer> value) throws Exception {
return accumulator + value.f1;
}
}).map(new RichMapFunction<Integer, Tuple2<Integer, Integer>>() {
private static final long serialVersionUID = 8538355101606319744L;
int key = -1;
@Override
public Tuple2<Integer, Integer> map(Integer value) throws Exception {
if (key == -1){
key = MathUtils.murmurHash(value) % numKeys;
}
return new Tuple2<>(key, value);
}
}).split(new OutputSelector<Tuple2<Integer, Integer>>() {
private static final long serialVersionUID = -8439325199163362470L;
@Override
public Iterable<String> select(Tuple2<Integer, Integer> value) {
List<String> output = new ArrayList<>();
output.add(value.f0 + "");
return output;
}
});
final MemorySinkFunction sinkFunction1 = new MemorySinkFunction(0);
final List<Integer> actualResult1 = new ArrayList<>();
MemorySinkFunction.registerCollection(0, actualResult1);
splittedResult.select("0").map(new MapFunction<Tuple2<Integer,Integer>, Integer>() {
private static final long serialVersionUID = 2114608668010092995L;
@Override
public Integer map(Tuple2<Integer, Integer> value) throws Exception {
return value.f1;
}
}).addSink(sinkFunction1);
final MemorySinkFunction sinkFunction2 = new MemorySinkFunction(1);
final List<Integer> actualResult2 = new ArrayList<>();
MemorySinkFunction.registerCollection(1, actualResult2);
splittedResult.select("1").map(new MapFunction<Tuple2<Integer, Integer>, Integer>() {
private static final long serialVersionUID = 5631104389744681308L;
@Override
public Integer map(Tuple2<Integer, Integer> value) throws Exception {
return value.f1;
}
}).addSink(sinkFunction2);
Collection<Integer> expected1 = new ArrayList<>(10);
Collection<Integer> expected2 = new ArrayList<>(10);
int counter1 = 0;
int counter2 = 0;
for (int i = 0; i < numElements; i++) {
if (MathUtils.murmurHash(i) % numKeys == 0) {
counter1 += i;
expected1.add(counter1);
} else {
counter2 += i;
expected2.add(counter2);
}
}
env.execute();
Collections.sort(actualResult1);
Collections.sort(actualResult2);
Assert.assertEquals(expected1, actualResult1);
Assert.assertEquals(expected2, actualResult2);
MemorySinkFunction.clear();
}
/**
* Tests whether the fold operation can also be called with non Java serializable types.
*/
@Test
public void testFoldOperationWithNonJavaSerializableType() throws Exception {
final int numElements = 10;
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Tuple2<Integer, NonSerializable>> input = env.addSource(new NonSerializableTupleSource(numElements));
final MemorySinkFunction sinkFunction = new MemorySinkFunction(0);
final ArrayList<Integer> actualResult = new ArrayList<>();
MemorySinkFunction.registerCollection(0, actualResult);
input
.keyBy(0)
.fold(
new NonSerializable(42),
new FoldFunction<Tuple2<Integer, NonSerializable>, NonSerializable>() {
private static final long serialVersionUID = 2705497830143608897L;
@Override
public NonSerializable fold(NonSerializable accumulator, Tuple2<Integer, NonSerializable> value) throws Exception {
return new NonSerializable(accumulator.value + value.f1.value);
}
})
.map(new MapFunction<NonSerializable, Integer>() {
private static final long serialVersionUID = 6906984044674568945L;
@Override
public Integer map(NonSerializable value) throws Exception {
return value.value;
}
})
.addSink(sinkFunction);
Collection<Integer> expected = new ArrayList<>(10);
for (int i = 0; i < numElements; i++) {
expected.add(42 + i );
}
env.execute();
Collections.sort(actualResult);
Assert.assertEquals(expected, actualResult);
MemorySinkFunction.clear();
}
/**
* Tests the basic functionality of the AsyncWaitOperator: Processing a limited stream of
* elements by doubling their value. This is tested in for the ordered and unordered mode.
*/
@Test
public void testAsyncWaitOperator() throws Exception {
final int numElements = 5;
final long timeout = 1000L;
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Tuple2<Integer, NonSerializable>> input = env.addSource(new NonSerializableTupleSource(numElements));
AsyncFunction<Tuple2<Integer, NonSerializable>, Integer> function = new RichAsyncFunction<Tuple2<Integer, NonSerializable>, Integer>() {
private static final long serialVersionUID = 7000343199829487985L;
transient ExecutorService executorService;
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
executorService = Executors.newFixedThreadPool(numElements);
}
@Override
public void close() throws Exception {
super.close();
executorService.shutdownNow();
}
@Override
public void asyncInvoke(final Tuple2<Integer, NonSerializable> input,
final AsyncCollector<Integer> collector) throws Exception {
executorService.submit(new Runnable() {
@Override
public void run() {
collector.collect(Collections.singletonList(input.f0 + input.f0));
}
});
}
};
DataStream<Integer> orderedResult = AsyncDataStream.orderedWait(
input,
function,
timeout,
TimeUnit.MILLISECONDS,
2).setParallelism(1);
// save result from ordered process
final MemorySinkFunction sinkFunction1 = new MemorySinkFunction(0);
final List<Integer> actualResult1 = new ArrayList<>(numElements);
MemorySinkFunction.registerCollection(0, actualResult1);
orderedResult.addSink(sinkFunction1).setParallelism(1);
DataStream<Integer> unorderedResult = AsyncDataStream.unorderedWait(
input,
function,
timeout,
TimeUnit.MILLISECONDS,
2);
// save result from unordered process
final MemorySinkFunction sinkFunction2 = new MemorySinkFunction(1);
final List<Integer> actualResult2 = new ArrayList<>(numElements);
MemorySinkFunction.registerCollection(1, actualResult2);
unorderedResult.addSink(sinkFunction2);
Collection<Integer> expected = new ArrayList<>(10);
for (int i = 0; i < numElements; i++) {
expected.add(i+i);
}
env.execute();
Assert.assertEquals(expected, actualResult1);
Collections.sort(actualResult2);
Assert.assertEquals(expected, actualResult2);
MemorySinkFunction.clear();
}
private static class NonSerializable {
// This makes the type non-serializable
private final Object obj = new Object();
private final int value;
public NonSerializable(int value) {
this.value = value;
}
}
private static class NonSerializableTupleSource implements SourceFunction<Tuple2<Integer, NonSerializable>> {
private static final long serialVersionUID = 3949171986015451520L;
private final int numElements;
public NonSerializableTupleSource(int numElements) {
this.numElements = numElements;
}
@Override
public void run(SourceContext<Tuple2<Integer, NonSerializable>> ctx) throws Exception {
for (int i = 0; i < numElements; i++) {
ctx.collect(new Tuple2<>(i, new NonSerializable(i)));
}
}
@Override
public void cancel() {}
}
private static class TupleSource implements SourceFunction<Tuple2<Integer, Integer>> {
private static final long serialVersionUID = -8110466235852024821L;
private final int numElements;
private final int numKeys;
public TupleSource(int numElements, int numKeys) {
this.numElements = numElements;
this.numKeys = numKeys;
}
@Override
public void run(SourceContext<Tuple2<Integer, Integer>> ctx) throws Exception {
for (int i = 0; i < numElements; i++) {
// keys '1' and '2' hash to different buckets
Tuple2<Integer, Integer> result = new Tuple2<>(1 + (MathUtils.murmurHash(i) % numKeys), i);
ctx.collect(result);
}
}
@Override
public void cancel() {
}
}
private static class MemorySinkFunction implements SinkFunction<Integer> {
private static Map<Integer, Collection<Integer>> collections = new ConcurrentHashMap<>();
private static final long serialVersionUID = -8815570195074103860L;
private final int key;
public MemorySinkFunction(int key) {
this.key = key;
}
@Override
public void invoke(Integer value) throws Exception {
Collection<Integer> collection = collections.get(key);
synchronized (collection) {
collection.add(value);
}
}
public static void registerCollection(int key, Collection<Integer> collection) {
collections.put(key, collection);
}
public static void clear() {
collections.clear();
}
}
}