/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.streaming.api; import static org.junit.Assert.assertEquals; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import com.google.common.collect.ImmutableList; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.common.operators.Keys; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction.AggregationType; import org.apache.flink.streaming.api.functions.aggregation.ComparableAggregator; import org.apache.flink.streaming.api.functions.aggregation.SumAggregator; import org.apache.flink.streaming.api.operators.StreamGroupedReduce; import org.apache.flink.streaming.util.MockContext; import org.apache.flink.streaming.util.keys.KeySelectorUtil; import org.junit.Test; public class AggregationFunctionTest { @Test public void groupSumIntegerTest() throws Exception { // preparing expected outputs List<Tuple2<Integer, Integer>> expectedGroupSumList = new ArrayList<>(); List<Tuple2<Integer, Integer>> expectedGroupMinList = new ArrayList<>(); List<Tuple2<Integer, Integer>> expectedGroupMaxList = new ArrayList<>(); int groupedSum0 = 0; int groupedSum1 = 0; int groupedSum2 = 0; for (int i = 0; i < 9; i++) { int groupedSum; switch (i % 3) { case 0: groupedSum = groupedSum0 += i; break; case 1: groupedSum = groupedSum1 += i; break; default: groupedSum = groupedSum2 += i; break; } expectedGroupSumList.add(new Tuple2<>(i % 3, groupedSum)); expectedGroupMinList.add(new Tuple2<>(i % 3, i % 3)); expectedGroupMaxList.add(new Tuple2<>(i % 3, i)); } // some necessary boiler plate TypeInformation<Tuple2<Integer, Integer>> typeInfo = TypeExtractor.getForObject(new Tuple2<>(0, 0)); ExecutionConfig config = new ExecutionConfig(); KeySelector<Tuple2<Integer, Integer>, Tuple> keySelector = KeySelectorUtil.getSelectorForKeys( new Keys.ExpressionKeys<>(new int[]{0}, typeInfo), typeInfo, config); TypeInformation<Tuple> keyType = TypeExtractor.getKeySelectorTypes(keySelector, typeInfo); // aggregations tested ReduceFunction<Tuple2<Integer, Integer>> sumFunction = new SumAggregator<>(1, typeInfo, config); ReduceFunction<Tuple2<Integer, Integer>> minFunction = new ComparableAggregator<>( 1, typeInfo, AggregationType.MIN, config); ReduceFunction<Tuple2<Integer, Integer>> maxFunction = new ComparableAggregator<>( 1, typeInfo, AggregationType.MAX, config); List<Tuple2<Integer, Integer>> groupedSumList = MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(sumFunction, typeInfo.createSerializer(config)), getInputList(), keySelector, keyType); List<Tuple2<Integer, Integer>> groupedMinList = MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(minFunction, typeInfo.createSerializer(config)), getInputList(), keySelector, keyType); List<Tuple2<Integer, Integer>> groupedMaxList = MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(maxFunction, typeInfo.createSerializer(config)), getInputList(), keySelector, keyType); assertEquals(expectedGroupSumList, groupedSumList); assertEquals(expectedGroupMinList, groupedMinList); assertEquals(expectedGroupMaxList, groupedMaxList); } @Test public void pojoGroupSumIntegerTest() throws Exception { // preparing expected outputs List<MyPojo> expectedGroupSumList = new ArrayList<>(); List<MyPojo> expectedGroupMinList = new ArrayList<>(); List<MyPojo> expectedGroupMaxList = new ArrayList<>(); int groupedSum0 = 0; int groupedSum1 = 0; int groupedSum2 = 0; for (int i = 0; i < 9; i++) { int groupedSum; switch (i % 3) { case 0: groupedSum = groupedSum0 += i; break; case 1: groupedSum = groupedSum1 += i; break; default: groupedSum = groupedSum2 += i; break; } expectedGroupSumList.add(new MyPojo(i % 3, groupedSum)); expectedGroupMinList.add(new MyPojo(i % 3, i % 3)); expectedGroupMaxList.add(new MyPojo(i % 3, i)); } // some necessary boiler plate TypeInformation<MyPojo> typeInfo = TypeExtractor.getForObject(new MyPojo(0, 0)); ExecutionConfig config = new ExecutionConfig(); KeySelector<MyPojo, Tuple> keySelector = KeySelectorUtil.getSelectorForKeys( new Keys.ExpressionKeys<>(new String[]{"f0"}, typeInfo), typeInfo, config); TypeInformation<Tuple> keyType = TypeExtractor.getKeySelectorTypes(keySelector, typeInfo); // aggregations tested ReduceFunction<MyPojo> sumFunction = new SumAggregator<>("f1", typeInfo, config); ReduceFunction<MyPojo> minFunction = new ComparableAggregator<>("f1", typeInfo, AggregationType.MIN, false, config); ReduceFunction<MyPojo> maxFunction = new ComparableAggregator<>("f1", typeInfo, AggregationType.MAX, false, config); List<MyPojo> groupedSumList = MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(sumFunction, typeInfo.createSerializer(config)), getInputPojoList(), keySelector, keyType); List<MyPojo> groupedMinList = MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(minFunction, typeInfo.createSerializer(config)), getInputPojoList(), keySelector, keyType); List<MyPojo> groupedMaxList = MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(maxFunction, typeInfo.createSerializer(config)), getInputPojoList(), keySelector, keyType); assertEquals(expectedGroupSumList, groupedSumList); assertEquals(expectedGroupMinList, groupedMinList); assertEquals(expectedGroupMaxList, groupedMaxList); } @Test public void minMaxByTest() throws Exception { // Tuples are grouped on field 0, aggregated on field 1 // preparing expected outputs List<Tuple3<Integer, Integer, Integer>> maxByFirstExpected = ImmutableList.of( Tuple3.of(0,0,0), Tuple3.of(0,1,1), Tuple3.of(0,2,2), Tuple3.of(0,2,2), Tuple3.of(0,2,2), Tuple3.of(0,2,2), Tuple3.of(0,2,2), Tuple3.of(0,2,2), Tuple3.of(0,2,2)); List<Tuple3<Integer, Integer, Integer>> maxByLastExpected = ImmutableList.of( Tuple3.of(0, 0, 0), Tuple3.of(0, 1, 1), Tuple3.of(0, 2, 2), Tuple3.of(0, 2, 2), Tuple3.of(0, 2, 2), Tuple3.of(0, 2, 5), Tuple3.of(0, 2, 5), Tuple3.of(0, 2, 5), Tuple3.of(0, 2, 8)); List<Tuple3<Integer, Integer, Integer>> minByFirstExpected = ImmutableList.of( Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0), Tuple3.of(0,0,0)); List<Tuple3<Integer, Integer, Integer>> minByLastExpected = ImmutableList.of( Tuple3.of(0, 0, 0), Tuple3.of(0, 0, 0), Tuple3.of(0, 0, 0), Tuple3.of(0, 0, 3), Tuple3.of(0, 0, 3), Tuple3.of(0, 0, 3), Tuple3.of(0, 0, 6), Tuple3.of(0, 0, 6), Tuple3.of(0, 0, 6)); // some necessary boiler plate TypeInformation<Tuple3<Integer, Integer, Integer>> typeInfo = TypeExtractor .getForObject(Tuple3.of(0,0,0)); ExecutionConfig config = new ExecutionConfig(); KeySelector<Tuple3<Integer, Integer, Integer>, Tuple> keySelector = KeySelectorUtil.getSelectorForKeys( new Keys.ExpressionKeys<>(new int[]{0}, typeInfo), typeInfo, config); TypeInformation<Tuple> keyType = TypeExtractor.getKeySelectorTypes(keySelector, typeInfo); // aggregations tested ReduceFunction<Tuple3<Integer, Integer, Integer>> maxByFunctionFirst = new ComparableAggregator<>(1, typeInfo, AggregationType.MAXBY, true, config); ReduceFunction<Tuple3<Integer, Integer, Integer>> maxByFunctionLast = new ComparableAggregator<>(1, typeInfo, AggregationType.MAXBY, false, config); ReduceFunction<Tuple3<Integer, Integer, Integer>> minByFunctionFirst = new ComparableAggregator<>(1, typeInfo, AggregationType.MINBY, true, config); ReduceFunction<Tuple3<Integer, Integer, Integer>> minByFunctionLast = new ComparableAggregator<>(1, typeInfo, AggregationType.MINBY, false, config); assertEquals(maxByFirstExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(maxByFunctionFirst, typeInfo.createSerializer(config)), getInputByList(), keySelector, keyType)); assertEquals(maxByLastExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(maxByFunctionLast, typeInfo.createSerializer(config)), getInputByList(), keySelector, keyType)); assertEquals(minByLastExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(minByFunctionLast, typeInfo.createSerializer(config)), getInputByList(), keySelector, keyType)); assertEquals(minByFirstExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(minByFunctionFirst, typeInfo.createSerializer(config)), getInputByList(), keySelector, keyType)); } @Test public void pojoMinMaxByTest() throws Exception { // Pojos are grouped on field 0, aggregated on field 1 // preparing expected outputs List<MyPojo3> maxByFirstExpected = ImmutableList.of( new MyPojo3(0, 0), new MyPojo3(1, 1), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 2)); List<MyPojo3> maxByLastExpected = ImmutableList.of( new MyPojo3(0, 0), new MyPojo3(1, 1), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 2), new MyPojo3(2, 5), new MyPojo3(2, 5), new MyPojo3(2, 5), new MyPojo3(2, 8)); List<MyPojo3> minByFirstExpected = ImmutableList.of( new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0)); List<MyPojo3> minByLastExpected = ImmutableList.of( new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 0), new MyPojo3(0, 3), new MyPojo3(0, 3), new MyPojo3(0, 3), new MyPojo3(0, 6), new MyPojo3(0, 6), new MyPojo3(0, 6)); // some necessary boiler plate TypeInformation<MyPojo3> typeInfo = TypeExtractor.getForObject(new MyPojo3(0, 0)); ExecutionConfig config = new ExecutionConfig(); KeySelector<MyPojo3, Tuple> keySelector = KeySelectorUtil.getSelectorForKeys( new Keys.ExpressionKeys<>(new String[]{"f0"}, typeInfo), typeInfo, config); TypeInformation<Tuple> keyType = TypeExtractor.getKeySelectorTypes(keySelector, typeInfo); // aggregations tested ReduceFunction<MyPojo3> maxByFunctionFirst = new ComparableAggregator<>("f1", typeInfo, AggregationType.MAXBY, true, config); ReduceFunction<MyPojo3> maxByFunctionLast = new ComparableAggregator<>("f1", typeInfo, AggregationType.MAXBY, false, config); ReduceFunction<MyPojo3> minByFunctionFirst = new ComparableAggregator<>("f1", typeInfo, AggregationType.MINBY, true, config); ReduceFunction<MyPojo3> minByFunctionLast = new ComparableAggregator<>("f1", typeInfo, AggregationType.MINBY, false, config); assertEquals(maxByFirstExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(maxByFunctionFirst, typeInfo.createSerializer(config)), getInputByPojoList(), keySelector, keyType)); assertEquals(maxByLastExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(maxByFunctionLast, typeInfo.createSerializer(config)), getInputByPojoList(), keySelector, keyType)); assertEquals(minByLastExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(minByFunctionLast, typeInfo.createSerializer(config)), getInputByPojoList(), keySelector, keyType)); assertEquals(minByFirstExpected, MockContext.createAndExecuteForKeyedStream( new StreamGroupedReduce<>(minByFunctionFirst, typeInfo.createSerializer(config)), getInputByPojoList(), keySelector, keyType)); } // ************************************************************************* // UTILS // ************************************************************************* private List<Tuple2<Integer, Integer>> getInputList() { ArrayList<Tuple2<Integer, Integer>> inputList = new ArrayList<>(); for (int i = 0; i < 9; i++) { inputList.add(Tuple2.of(i % 3, i)); } return inputList; } private List<MyPojo> getInputPojoList() { ArrayList<MyPojo> inputList = new ArrayList<>(); for (int i = 0; i < 9; i++) { inputList.add(new MyPojo(i % 3, i)); } return inputList; } private List<Tuple3<Integer, Integer, Integer>> getInputByList() { ArrayList<Tuple3<Integer, Integer, Integer>> inputList = new ArrayList<>(); for (int i = 0; i < 9; i++) { inputList.add(Tuple3.of(0, i % 3, i)); } return inputList; } private List<MyPojo3> getInputByPojoList() { ArrayList<MyPojo3> inputList = new ArrayList<>(); for (int i = 0; i < 9; i++) { inputList.add(new MyPojo3(i % 3, i)); } return inputList; } public static class MyPojo implements Serializable { private static final long serialVersionUID = 1L; public int f0; public int f1; public MyPojo(int f0, int f1) { this.f0 = f0; this.f1 = f1; } public MyPojo() { } @Override public String toString() { return "POJO(" + f0 + "," + f1 + ")"; } @Override public boolean equals(Object other) { if (other instanceof MyPojo) { return this.f0 == ((MyPojo) other).f0 && this.f1 == ((MyPojo) other).f1; } else { return false; } } } public static class MyPojo3 implements Serializable { private static final long serialVersionUID = 1L; public int f0; public int f1; public int f2; // Field 0 is always initialized to 0 public MyPojo3(int f1, int f2) { this.f1 = f1; this.f2 = f2; } public MyPojo3() { } @Override public String toString() { return "POJO3(" + f0 + "," + f1 + "," + f2 + ")"; } @Override public boolean equals(Object other) { if (other instanceof MyPojo3) { return this.f0 == ((MyPojo3) other).f0 && this.f1 == ((MyPojo3) other).f1 && this.f2 == ((MyPojo3) other).f2; } else { return false; } } } }