/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.streaming.runtime;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
import org.apache.flink.test.streaming.runtime.util.NoOpIntMap;
import org.apache.flink.test.streaming.runtime.util.TestListResultSink;
import org.junit.Test;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
/**
 * IT case that tests the different stream partitioning schemes.
 *
 * <p>Each partitioned stream is followed by a {@link SubtaskIndexAssigner}, which tags every record
 * with the index of the subtask that received it. The verifiers can therefore check exactly which
 * subtask saw which element.
 */
@SuppressWarnings("serial")
public class PartitionerITCase extends StreamingMultipleProgramsTestBase {

	/** Number of elements emitted by the source in {@link #partitionerTest()}. */
	private static final int NUM_SOURCE_ELEMENTS = 7;

	/** A forward connection from a parallelism-1 source to a parallelism-3 map must be rejected. */
	@Test(expected = UnsupportedOperationException.class)
	public void testForwardFailsLowToHighParallelism() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Integer> src = env.fromElements(1, 2, 3);

		// this doesn't work because it goes from 1 to 3
		src.forward().map(new NoOpIntMap());

		env.execute();
	}

	/** A forward connection from a parallelism-3 operator to a parallelism-1 map must be rejected. */
	@Test(expected = UnsupportedOperationException.class)
	public void testForwardFailsHightToLowParallelism() throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		// this does a rebalance that works
		DataStream<Integer> src = env.fromElements(1, 2, 3).map(new NoOpIntMap());

		// this doesn't work because it goes from 3 to 1
		src.forward().map(new NoOpIntMap()).setParallelism(1);

		env.execute();
	}

	/**
	 * Routes the same seven-element stream through hash, custom, broadcast, rebalance, forward and
	 * global partitioning in a single job, then verifies where every element ended up.
	 *
	 * @throws Exception if the job fails; it is propagated so JUnit reports the full cause chain
	 */
	@Test
	public void partitionerTest() throws Exception {

		TestListResultSink<Tuple2<Integer, String>> hashPartitionResultSink =
			new TestListResultSink<Tuple2<Integer, String>>();
		TestListResultSink<Tuple2<Integer, String>> customPartitionResultSink =
			new TestListResultSink<Tuple2<Integer, String>>();
		TestListResultSink<Tuple2<Integer, String>> broadcastPartitionResultSink =
			new TestListResultSink<Tuple2<Integer, String>>();
		TestListResultSink<Tuple2<Integer, String>> forwardPartitionResultSink =
			new TestListResultSink<Tuple2<Integer, String>>();
		TestListResultSink<Tuple2<Integer, String>> rebalancePartitionResultSink =
			new TestListResultSink<Tuple2<Integer, String>>();
		TestListResultSink<Tuple2<Integer, String>> globalPartitionResultSink =
			new TestListResultSink<Tuple2<Integer, String>>();

		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setParallelism(3);

		// NUM_SOURCE_ELEMENTS must match the number of elements listed here
		DataStream<Tuple1<String>> src = env.fromElements(
			new Tuple1<String>("a"),
			new Tuple1<String>("b"),
			new Tuple1<String>("b"),
			new Tuple1<String>("a"),
			new Tuple1<String>("a"),
			new Tuple1<String>("c"),
			new Tuple1<String>("a")
		);

		// partition by hash
		src
			.keyBy(0)
			.map(new SubtaskIndexAssigner())
			.addSink(hashPartitionResultSink);

		// partition custom: "c" goes to subtask 2, everything else to subtask 0
		DataStream<Tuple2<Integer, String>> partitionCustom = src
			.partitionCustom(new Partitioner<String>() {
				@Override
				public int partition(String key, int numPartitions) {
					if (key.equals("c")) {
						return 2;
					} else {
						return 0;
					}
				}
			}, 0)
			.map(new SubtaskIndexAssigner());

		partitionCustom.addSink(customPartitionResultSink);

		// partition broadcast
		src.broadcast().map(new SubtaskIndexAssigner()).addSink(broadcastPartitionResultSink);

		// partition rebalance
		src.rebalance().map(new SubtaskIndexAssigner()).addSink(rebalancePartitionResultSink);

		// partition forward
		src.map(new MapFunction<Tuple1<String>, Tuple1<String>>() {
			private static final long serialVersionUID = 1L;
			@Override
			public Tuple1<String> map(Tuple1<String> value) throws Exception {
				return value;
			}
		})
			.forward()
			.map(new SubtaskIndexAssigner())
			.addSink(forwardPartitionResultSink);

		// partition global
		src.global().map(new SubtaskIndexAssigner()).addSink(globalPartitionResultSink);

		env.execute();

		List<Tuple2<Integer, String>> hashPartitionResult = hashPartitionResultSink.getResult();
		List<Tuple2<Integer, String>> customPartitionResult = customPartitionResultSink.getResult();
		List<Tuple2<Integer, String>> broadcastPartitionResult = broadcastPartitionResultSink.getResult();
		List<Tuple2<Integer, String>> forwardPartitionResult = forwardPartitionResultSink.getResult();
		List<Tuple2<Integer, String>> rebalancePartitionResult = rebalancePartitionResultSink.getResult();
		List<Tuple2<Integer, String>> globalPartitionResult = globalPartitionResultSink.getResult();

		verifyHashPartitioning(hashPartitionResult);
		verifyCustomPartitioning(customPartitionResult);
		verifyBroadcastPartitioning(broadcastPartitionResult);
		// The identity map in the forward pipeline changes parallelism from the source's to 3, so
		// it is fed through a rebalance. forward() then preserves that distribution, which is why
		// the forward result is checked against the rebalance expectation.
		verifyRebalancePartitioning(forwardPartitionResult);
		verifyRebalancePartitioning(rebalancePartitionResult);
		verifyGlobalPartitioning(globalPartitionResult);
	}

	/**
	 * Checks key-based partitioning: every record arrives, every key is seen, and all records with
	 * the same key land on the same subtask.
	 */
	private static void verifyHashPartitioning(List<Tuple2<Integer, String>> hashPartitionResult) {
		HashMap<String, Integer> subtaskOfKey = new HashMap<String, Integer>();
		for (Tuple2<Integer, String> elem : hashPartitionResult) {
			Integer firstSubtask = subtaskOfKey.get(elem.f1);
			if (firstSubtask == null) {
				subtaskOfKey.put(elem.f1, elem.f0);
			} else if (!Objects.equals(firstSubtask, elem.f0)) {
				fail("Key '" + elem.f1 + "' was sent to subtask " + firstSubtask
					+ " and to subtask " + elem.f0);
			}
		}
		assertEquals(NUM_SOURCE_ELEMENTS, hashPartitionResult.size());
		assertEquals(new HashSet<String>(Arrays.asList("a", "b", "c")), subtaskOfKey.keySet());
	}

	/**
	 * Checks the custom partitioner used in {@link #partitionerTest()}: "c" must be on subtask 2
	 * and every other value on subtask 0, with no record lost or duplicated.
	 */
	private static void verifyCustomPartitioning(List<Tuple2<Integer, String>> customPartitionResult) {
		assertEquals(NUM_SOURCE_ELEMENTS, customPartitionResult.size());
		for (Tuple2<Integer, String> stringWithSubtask : customPartitionResult) {
			Integer expectedSubtask = stringWithSubtask.f1.equals("c") ? 2 : 0;
			assertEquals("subtask of '" + stringWithSubtask.f1 + "'", expectedSubtask, stringWithSubtask.f0);
		}
	}

	/** Checks that each of the three subtasks received every source element exactly once. */
	private static void verifyBroadcastPartitioning(List<Tuple2<Integer, String>> broadcastPartitionResult) {
		List<Tuple2<Integer, String>> expected = Arrays.asList(
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(0, "b"),
			new Tuple2<Integer, String>(0, "b"),
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(0, "c"),
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(1, "a"),
			new Tuple2<Integer, String>(1, "b"),
			new Tuple2<Integer, String>(1, "b"),
			new Tuple2<Integer, String>(1, "a"),
			new Tuple2<Integer, String>(1, "a"),
			new Tuple2<Integer, String>(1, "c"),
			new Tuple2<Integer, String>(1, "a"),
			new Tuple2<Integer, String>(2, "a"),
			new Tuple2<Integer, String>(2, "b"),
			new Tuple2<Integer, String>(2, "b"),
			new Tuple2<Integer, String>(2, "a"),
			new Tuple2<Integer, String>(2, "a"),
			new Tuple2<Integer, String>(2, "c"),
			new Tuple2<Integer, String>(2, "a"));

		assertSameElementsIgnoringOrder(expected, broadcastPartitionResult);
	}

	/** Checks the round-robin assignment of the source elements over subtasks 0, 1, 2. */
	private static void verifyRebalancePartitioning(List<Tuple2<Integer, String>> rebalancePartitionResult) {
		List<Tuple2<Integer, String>> expected = Arrays.asList(
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(1, "b"),
			new Tuple2<Integer, String>(2, "b"),
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(1, "a"),
			new Tuple2<Integer, String>(2, "c"),
			new Tuple2<Integer, String>(0, "a"));

		assertSameElementsIgnoringOrder(expected, rebalancePartitionResult);
	}

	/** Checks that every source element ended up on subtask 0, exactly once. */
	private static void verifyGlobalPartitioning(List<Tuple2<Integer, String>> globalPartitionResult) {
		List<Tuple2<Integer, String>> expected = Arrays.asList(
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(0, "b"),
			new Tuple2<Integer, String>(0, "b"),
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(0, "a"),
			new Tuple2<Integer, String>(0, "c"),
			new Tuple2<Integer, String>(0, "a"));

		assertSameElementsIgnoringOrder(expected, globalPartitionResult);
	}

	/**
	 * Asserts that {@code actual} holds exactly the elements of {@code expected}, each with the
	 * same multiplicity, in any order.
	 *
	 * <p>A plain {@code HashSet} comparison would collapse duplicates. The expectations above
	 * contain repeated (subtask, value) pairs, so lost or duplicated records would go unnoticed.
	 */
	private static <T> void assertSameElementsIgnoringOrder(List<T> expected, List<T> actual) {
		assertEquals("elements (ignoring order) differ", countOccurrences(expected), countOccurrences(actual));
	}

	/** Counts how often each distinct element (by {@code equals}) occurs in {@code elements}. */
	private static <T> HashMap<T, Integer> countOccurrences(List<T> elements) {
		HashMap<T, Integer> counts = new HashMap<T, Integer>();
		for (T element : elements) {
			Integer count = counts.get(element);
			counts.put(element, count == null ? 1 : count + 1);
		}
		return counts;
	}

	/** Tags each incoming record with the index of the parallel subtask that processes it. */
	private static class SubtaskIndexAssigner extends RichMapFunction<Tuple1<String>, Tuple2<Integer, String>> {

		private static final long serialVersionUID = 1L;

		/** Index of this subtask, captured once in {@link #open(Configuration)}. */
		private int indexOfSubtask;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			RuntimeContext runtimeContext = getRuntimeContext();
			indexOfSubtask = runtimeContext.getIndexOfThisSubtask();
		}

		@Override
		public Tuple2<Integer, String> map(Tuple1<String> value) throws Exception {
			return new Tuple2<Integer, String>(indexOfSubtask, value.f0);
		}
	}
}