/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.operators;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntComparator;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.runtime.TupleComparator;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
import org.apache.flink.runtime.operators.testutils.DelayingIterator;
import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.InfiniteIntTupleIterator;
import org.apache.flink.runtime.operators.testutils.UnaryOperatorTestBase;
import org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.junit.Test;
import java.util.ArrayList;
import static org.junit.Assert.*;
/**
 * Tests for {@link GroupReduceCombineDriver} configured with the
 * {@link DriverStrategy#SORTED_GROUP_COMBINE} strategy.
 *
 * <p>Three scenarios are covered: a regular combine run whose output is checked, a combine
 * function that throws {@link ExpectedTestException}, and cancellation of a task that is fed by
 * a slow, infinite input.
 */
public class CombineTaskTest
        extends UnaryOperatorTestBase<
                RichGroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>,
                Tuple2<Integer, Integer>,
                Tuple2<Integer, Integer>> {

    /** Memory amount passed to the base class and used to derive the driver's memory share. */
    private static final long COMBINE_MEM = 3 * 1024 * 1024;

    /** {@link #COMBINE_MEM} divided by the memory manager's total size; set as relative driver memory. */
    private final double combineFrac;

    /** Receives the records emitted by the driver in {@link #testCombineTask()}. */
    private final ArrayList<Tuple2<Integer, Integer>> outList = new ArrayList<>();

    @SuppressWarnings("unchecked")
    private final TypeSerializer<Tuple2<Integer, Integer>> serializer = new TupleSerializer<>(
            (Class<Tuple2<Integer, Integer>>) (Class<?>) Tuple2.class,
            new TypeSerializer<?>[] { IntSerializer.INSTANCE, IntSerializer.INSTANCE });

    /** Compares tuples on field 0 only (the key). */
    private final TypeComparator<Tuple2<Integer, Integer>> comparator = new TupleComparator<>(
            new int[]{0},
            new TypeComparator<?>[] { new IntComparator(true) },
            new TypeSerializer<?>[] { IntSerializer.INSTANCE });

    /**
     * Creates the test with the given execution configuration.
     *
     * @param config execution configuration forwarded to the test base
     */
    public CombineTaskTest(ExecutionConfig config) {
        super(config, COMBINE_MEM, 0);
        combineFrac = (double) COMBINE_MEM / this.getMemoryManager().getMemorySize();
    }

    /**
     * Runs the combiner over 100 keys with 20 values each and checks that exactly one record per
     * key is emitted and that every emitted value equals the expected per-key sum.
     */
    @Test
    public void testCombineTask() {
        try {
            int keyCnt = 100;
            int valCnt = 20;

            setInput(new UniformIntTupleGenerator(keyCnt, valCnt, false), serializer);
            // NOTE(review): the comparator is registered twice, exactly as before; confirm
            // against the test base / driver why two registrations are needed.
            addDriverComparator(this.comparator);
            addDriverComparator(this.comparator);
            setOutput(this.outList, serializer);

            final GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> testTask =
                    newConfiguredCombineDriver();

            testDriver(testTask, MockCombiningReduceStub.class);

            // Expected per-key sum is 1 + 2 + ... + (valCnt - 1).
            // NOTE(review): presumably the generator emits the values 0..valCnt-1 for each key;
            // confirm against UniformIntTupleGenerator.
            int expSum = 0;
            for (int i = 1; i < valCnt; i++) {
                expSum += i;
            }

            assertEquals("Unexpected number of combined records.", keyCnt, this.outList.size());

            for (Tuple2<Integer, Integer> record : this.outList) {
                assertEquals("Wrong combined sum for key " + record.f0 + ".",
                        expSum, record.f1.intValue());
            }

            this.outList.clear();
        }
        catch (Exception e) {
            // Keep the original exception as the cause so the test report shows its stack trace.
            throw new AssertionError("Unexpected exception in testCombineTask: " + e, e);
        }
    }

    /**
     * Runs the driver with a combiner that throws {@link ExpectedTestException} and checks that
     * the exception reaches the caller of {@code testDriver}.
     */
    @Test
    public void testFailingCombineTask() {
        try {
            int keyCnt = 100;
            int valCnt = 20;

            setInput(new UniformIntTupleGenerator(keyCnt, valCnt, false), serializer);
            addDriverComparator(this.comparator);
            addDriverComparator(this.comparator);
            setOutput(new DiscardingOutputCollector<Tuple2<Integer, Integer>>());

            final GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> testTask =
                    newConfiguredCombineDriver();

            try {
                testDriver(testTask, MockFailingCombiningReduceStub.class);
                fail("Exception not forwarded.");
            }
            catch (ExpectedTestException etex) {
                // good!
            }
        }
        catch (Exception e) {
            // Keep the original exception as the cause so the test report shows its stack trace.
            throw new AssertionError("Unexpected exception in testFailingCombineTask: " + e, e);
        }
    }

    /**
     * Starts the driver on a slow, infinite input in a separate thread, cancels it after a short
     * wait and checks that the thread terminates before the deadline.
     */
    @Test
    public void testCancelCombineTaskSorting() {
        try {
            MutableObjectIterator<Tuple2<Integer, Integer>> slowInfiniteInput =
                    new DelayingIterator<>(new InfiniteIntTupleIterator(), 1);

            setInput(slowInfiniteInput, serializer);
            addDriverComparator(this.comparator);
            addDriverComparator(this.comparator);
            setOutput(new DiscardingOutputCollector<Tuple2<Integer, Integer>>());

            final GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> testTask =
                    newConfiguredCombineDriver();

            Thread taskRunner = new Thread() {
                @Override
                public void run() {
                    try {
                        testDriver(testTask, MockFailingCombiningReduceStub.class);
                    }
                    catch (Exception ignored) {
                        // exceptions may happen during canceling
                    }
                }
            };
            taskRunner.start();

            // give the task some time
            Thread.sleep(500);

            // cancel
            testTask.cancel();

            // make sure it reacts to the canceling in some time;
            // the runner is interrupted repeatedly until it exits or the deadline passes
            long deadline = System.currentTimeMillis() + 10000;
            do {
                taskRunner.interrupt();
                taskRunner.join(5000);
            }
            while (taskRunner.isAlive() && System.currentTimeMillis() < deadline);

            assertFalse("Task did not cancel properly within in 10 seconds.", taskRunner.isAlive());
        }
        catch (Exception e) {
            // Keep the original exception as the cause so the test report shows its stack trace.
            throw new AssertionError("Unexpected exception in testCancelCombineTaskSorting: " + e, e);
        }
    }

    /**
     * Applies the task configuration shared by all tests (sorted group combine strategy, relative
     * driver memory, two file handles) and creates a fresh driver.
     *
     * @return a new, not yet started combine driver
     */
    private GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>
            newConfiguredCombineDriver() {
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
        getTaskConfig().setRelativeMemoryDriver(combineFrac);
        getTaskConfig().setFilehandlesDriver(2);
        return new GroupReduceCombineDriver<>();
    }

    // ------------------------------------------------------------------------
    //  Test Combiners
    // ------------------------------------------------------------------------

    /**
     * Reducer/combiner that emits one tuple per group: the key of the last record seen and the
     * sum of all values in the group. For an empty group it emits {@code (0, 0)}.
     */
    public static class MockCombiningReduceStub implements
            GroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>,
            GroupCombineFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>
    {
        private static final long serialVersionUID = 1L;

        @Override
        public void reduce(Iterable<Tuple2<Integer, Integer>> records, Collector<Tuple2<Integer, Integer>> out) {
            int key = 0;
            int sum = 0;

            for (Tuple2<Integer, Integer> next : records) {
                key = next.f0;
                sum += next.f1;
            }

            out.collect(new Tuple2<>(key, sum));
        }

        /** Combining uses exactly the same aggregation as {@link #reduce}. */
        @Override
        public void combine(Iterable<Tuple2<Integer, Integer>> records, Collector<Tuple2<Integer, Integer>> out) {
            reduce(records, out);
        }
    }

    /**
     * Reducer/combiner that emits {@code (key, sum - key)} per group and whose
     * {@link #combine} throws {@link ExpectedTestException} from its tenth invocation onward
     * (counted per instance).
     */
    public static final class MockFailingCombiningReduceStub implements
            GroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>,
            GroupCombineFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>
    {
        private static final long serialVersionUID = 1L;

        /** Number of the {@code combine} call at which the stub starts to throw. */
        private static final int FAILING_COMBINE_CALL = 10;

        /** Number of {@code combine} invocations on this instance so far. */
        private int cnt;

        @Override
        public void reduce(Iterable<Tuple2<Integer, Integer>> records, Collector<Tuple2<Integer, Integer>> out) {
            out.collect(aggregate(records));
        }

        @Override
        public void combine(Iterable<Tuple2<Integer, Integer>> records, Collector<Tuple2<Integer, Integer>> out) {
            // The group is consumed completely before the failure check, as before.
            final Tuple2<Integer, Integer> result = aggregate(records);

            if (++this.cnt >= FAILING_COMBINE_CALL) {
                throw new ExpectedTestException();
            }

            out.collect(result);
        }

        /**
         * Iterates over the whole group and builds the result tuple.
         *
         * @param records the records of one group
         * @return a tuple of the last key seen and the value sum minus that key
         */
        private static Tuple2<Integer, Integer> aggregate(Iterable<Tuple2<Integer, Integer>> records) {
            int key = 0;
            int sum = 0;

            for (Tuple2<Integer, Integer> next : records) {
                key = next.f0;
                sum += next.f1;
            }

            int resultValue = sum - key;
            return new Tuple2<>(key, resultValue);
        }
    }
}