/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.FromElementsFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.util.Collector;
import org.apache.flink.util.SplittableIterator;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
 * Tests for {@link StreamExecutionEnvironment}: element/collection sources, the source functions
 * backing each factory method, and the bounds enforced on (max) parallelism.
 */
public class StreamExecutionEnvironmentTest {

	/**
	 * {@code fromElements} with an explicit base type must accept a mix of that type and its
	 * subclasses, and the resulting stream must be typed as the base type.
	 */
	@Test
	public void fromElementsWithBaseTypeTest1() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		DataStreamSource<ParentClass> source =
			env.fromElements(ParentClass.class, new SubClass(1, "Java"), new ParentClass(1, "hello"));

		assertEquals(ParentClass.class, source.getType().getTypeClass());
	}

	/**
	 * {@code fromElements} with a subclass as the declared type must reject an element that is only
	 * an instance of the parent class.
	 */
	@Test(expected = IllegalArgumentException.class)
	public void fromElementsWithBaseTypeTest2() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.fromElements(SubClass.class, new SubClass(1, "Java"), new ParentClass(1, "hello"));
	}

	/**
	 * A non-parallel collection source must refuse a parallelism other than 1, while a parallel
	 * collection source keeps the parallelism it was configured with.
	 *
	 * <p>Unexpected exceptions propagate so that JUnit reports them with their full stack trace.
	 */
	@Test
	@SuppressWarnings("unchecked")
	public void testFromCollectionParallelism() throws Exception {
		TypeInformation<Integer> typeInfo = BasicTypeInfo.INT_TYPE_INFO;
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStreamSource<Integer> dataStream1 = env.fromCollection(new DummySplittableIterator<Integer>(), typeInfo);
		try {
			dataStream1.setParallelism(4);
			fail("should throw an exception");
		} catch (IllegalArgumentException expected) {
			// expected: a non-parallel source cannot run with parallelism > 1
		}
		dataStream1.addSink(new DiscardingSink<Integer>());

		DataStreamSource<Integer> dataStream2 =
			env.fromParallelCollection(new DummySplittableIterator<Integer>(), typeInfo).setParallelism(4);
		dataStream2.addSink(new DiscardingSink<Integer>());

		// Building the plan must succeed before the stream graph is inspected.
		env.getExecutionPlan();

		assertEquals("Parallelism of collection source must be 1.",
			1,
			env.getStreamGraph().getStreamNode(dataStream1.getId()).getParallelism());
		assertEquals("Parallelism of parallel collection source must be 4.",
			4,
			env.getStreamGraph().getStreamNode(dataStream2.getId()).getParallelism());
	}

	/** Each source factory method must be backed by the expected {@link SourceFunction}. */
	@Test
	public void testSources() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		SourceFunction<Integer> srcFun = new SourceFunction<Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void run(SourceContext<Integer> ctx) throws Exception {
			}

			@Override
			public void cancel() {
			}
		};
		DataStreamSource<Integer> src1 = env.addSource(srcFun);
		src1.addSink(new DiscardingSink<Integer>());
		assertEquals(srcFun, getFunctionFromDataSource(src1));

		DataStreamSource<Long> src2 = env.generateSequence(0, 2);
		assertTrue(getFunctionFromDataSource(src2) instanceof StatefulSequenceSource);

		DataStreamSource<Long> src3 = env.fromElements(0L, 1L, 2L);
		assertTrue(getFunctionFromDataSource(src3) instanceof FromElementsFunction);

		List<Long> list = Arrays.asList(0L, 1L, 2L);
		DataStreamSource<Long> src4 = env.fromCollection(list);
		assertTrue(getFunctionFromDataSource(src4) instanceof FromElementsFunction);
	}

	/**
	 * Verifies the accepted ranges for operator parallelism, environment max parallelism, and
	 * operator max parallelism. Also verifies how the max parallelism is propagated when the job
	 * graph is generated.
	 */
	@Test
	public void testParallelismBounds() {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		SourceFunction<Integer> srcFun = new SourceFunction<Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void run(SourceContext<Integer> ctx) throws Exception {
			}

			@Override
			public void cancel() {
			}
		};

		SingleOutputStreamOperator<Object> operator =
			env.addSource(srcFun).flatMap(new FlatMapFunction<Integer, Object>() {
				private static final long serialVersionUID = 1L;

				@Override
				public void flatMap(Integer value, Collector<Object> out) throws Exception {
				}
			});

		// default value for max parallelism
		Assert.assertEquals(-1, operator.getTransformation().getMaxParallelism());

		// bounds for parallelism 1
		try {
			operator.setParallelism(0);
			Assert.fail("setParallelism(0) should be rejected");
		} catch (IllegalArgumentException expected) {
			// expected
		}

		// bounds for parallelism 2
		operator.setParallelism(1);
		Assert.assertEquals(1, operator.getParallelism());

		// bounds for parallelism 3
		operator.setParallelism(1 << 15);
		Assert.assertEquals(1 << 15, operator.getParallelism());

		// default value after generating
		env.getStreamGraph().getJobGraph();
		Assert.assertEquals(-1, operator.getTransformation().getMaxParallelism());

		// configured value after generating
		env.setMaxParallelism(42);
		env.getStreamGraph().getJobGraph();
		Assert.assertEquals(42, operator.getTransformation().getMaxParallelism());

		// bounds configured parallelism 1
		try {
			env.setMaxParallelism(0);
			Assert.fail("env.setMaxParallelism(0) should be rejected");
		} catch (IllegalArgumentException expected) {
			// expected
		}

		// bounds configured parallelism 2
		try {
			env.setMaxParallelism(1 + (1 << 15));
			Assert.fail("env.setMaxParallelism(2^15 + 1) should be rejected");
		} catch (IllegalArgumentException expected) {
			// expected
		}

		// bounds for max parallelism 1
		try {
			operator.setMaxParallelism(0);
			Assert.fail("operator.setMaxParallelism(0) should be rejected");
		} catch (IllegalArgumentException expected) {
			// expected
		}

		// bounds for max parallelism 2
		try {
			operator.setMaxParallelism(1 + (1 << 15));
			Assert.fail("operator.setMaxParallelism(2^15 + 1) should be rejected");
		} catch (IllegalArgumentException expected) {
			// expected
		}

		// bounds for max parallelism 3
		operator.setMaxParallelism(1);
		Assert.assertEquals(1, operator.getTransformation().getMaxParallelism());

		// bounds for max parallelism 4
		operator.setMaxParallelism(1 << 15);
		Assert.assertEquals(1 << 15, operator.getTransformation().getMaxParallelism());

		// an explicitly set operator max parallelism overrides the environment's configured value
		env.getStreamGraph().getJobGraph();
		Assert.assertEquals(1 << 15, operator.getTransformation().getMaxParallelism());
	}

	/////////////////////////////////////////////////////////////
	// Utilities
	/////////////////////////////////////////////////////////////

	/** Builds the stream graph of the stream's environment and returns the stream's operator. */
	private static StreamOperator<?> getOperatorFromDataStream(DataStream<?> dataStream) {
		StreamExecutionEnvironment env = dataStream.getExecutionEnvironment();
		StreamGraph streamGraph = env.getStreamGraph();
		return streamGraph.getStreamNode(dataStream.getId()).getOperator();
	}

	/**
	 * Returns the user {@link SourceFunction} behind a source.
	 *
	 * <p>A discarding sink is attached first, presumably so that the source is part of the
	 * generated stream graph. The source's operator is assumed to be an
	 * {@link AbstractUdfStreamOperator}.
	 */
	@SuppressWarnings("unchecked")
	private static <T> SourceFunction<T> getFunctionFromDataSource(DataStreamSource<T> dataStreamSource) {
		dataStreamSource.addSink(new DiscardingSink<T>());
		AbstractUdfStreamOperator<?, ?> operator =
			(AbstractUdfStreamOperator<?, ?>) getOperatorFromDataStream(dataStreamSource);
		return (SourceFunction<T>) operator.getUserFunction();
	}

	/** An empty {@link SplittableIterator} that yields no elements and produces no splits. */
	public static class DummySplittableIterator<T> extends SplittableIterator<T> {
		private static final long serialVersionUID = 1312752876092210499L;

		@SuppressWarnings("unchecked")
		@Override
		public Iterator<T>[] split(int numPartitions) {
			return (Iterator<T>[]) new Iterator<?>[0];
		}

		@Override
		public int getMaximumNumberOfSplits() {
			return 0;
		}

		@Override
		public boolean hasNext() {
			return false;
		}

		@Override
		public T next() {
			throw new NoSuchElementException();
		}

		@Override
		public void remove() {
			throw new UnsupportedOperationException();
		}
	}

	/** Base type used by the {@code fromElements} type-hierarchy tests. */
	public static class ParentClass {
		int num;
		String string;

		public ParentClass(int num, String string) {
			this.num = num;
			this.string = string;
		}
	}

	/** Subtype of {@link ParentClass} used by the {@code fromElements} type-hierarchy tests. */
	public static class SubClass extends ParentClass {
		public SubClass(int num, String string) {
			super(num, string);
		}
	}
}