/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.api.java.batch.table;
import java.io.Serializable;
import java.util.List;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.table.api.java.BatchTableEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.java.tuple.Tuple7;
import org.apache.flink.types.Row;
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.examples.java.WordCountTable.WC;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@RunWith(Parameterized.class)
public class AggregationsITCase extends TableProgramsTestBase {
public AggregationsITCase(TestExecutionMode mode, TableConfigMode configMode){
super(mode, configMode);
}
@Test
public void testAggregationTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
Table table = tableEnv.fromDataSet(CollectionDataSets.get3TupleDataSet(env));
Table result = table.select("f0.sum, f0.min, f0.max, f0.count, f0.avg");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "231,1,21,21,11";
compareResultAsText(results, expected);
}
@Test(expected = ValidationException.class)
public void testAggregationOnNonExistingField() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
Table table =
tableEnv.fromDataSet(CollectionDataSets.get3TupleDataSet(env));
Table result =
table.select("foo.avg");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "";
compareResultAsText(results, expected);
}
@Test
public void testWorkingAggregationDataTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSource<Tuple7<Byte, Short, Integer, Long, Float, Double, String>> input =
env.fromElements(
new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, "Hello"),
new Tuple7<>((byte) 2, (short) 2, 2, 2L, 2.0f, 2.0d, "Ciao"));
Table table = tableEnv.fromDataSet(input);
Table result =
table.select("f0.avg, f1.avg, f2.avg, f3.avg, f4.avg, f5.avg, f6.count");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "1,1,1,1,1.5,1.5,2";
compareResultAsText(results, expected);
}
@Test
public void testAggregationWithArithmetic() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSource<Tuple2<Float, String>> input =
env.fromElements(
new Tuple2<>(1f, "Hello"),
new Tuple2<>(2f, "Ciao"));
Table table =
tableEnv.fromDataSet(input);
Table result =
table.select("(f0 + 2).avg + 2, f1.count + 5");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "5.5,7";
compareResultAsText(results, expected);
}
@Test
public void testAggregationWithTwoCount() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSource<Tuple2<Float, String>> input =
env.fromElements(
new Tuple2<>(1f, "Hello"),
new Tuple2<>(2f, "Ciao"));
Table table =
tableEnv.fromDataSet(input);
Table result =
table.select("f0.count, f1.count");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "2,2";
compareResultAsText(results, expected);
}
@Test(expected = ValidationException.class)
public void testNonWorkingDataTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSource<Tuple2<Float, String>> input = env.fromElements(new Tuple2<>(1f, "Hello"));
Table table =
tableEnv.fromDataSet(input);
Table result =
// Must fail. Cannot compute SUM aggregate on String field.
table.select("f1.sum");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "";
compareResultAsText(results, expected);
}
@Test(expected = ValidationException.class)
public void testNoNestedAggregation() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSource<Tuple2<Float, String>> input = env.fromElements(new Tuple2<>(1f, "Hello"));
Table table =
tableEnv.fromDataSet(input);
Table result =
// Must fail. Aggregation on aggregation not allowed.
table.select("f0.sum.sum");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "";
compareResultAsText(results, expected);
}
@Test(expected = ValidationException.class)
public void testGroupingOnNonExistentField() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple3<Integer, Long, String>> input = CollectionDataSets.get3TupleDataSet(env);
tableEnv
.fromDataSet(input, "a, b, c")
// must fail. Field foo is not in input
.groupBy("foo")
.select("a.avg");
}
@Test(expected = ValidationException.class)
public void testGroupingInvalidSelection() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple3<Integer, Long, String>> input = CollectionDataSets.get3TupleDataSet(env);
tableEnv
.fromDataSet(input, "a, b, c")
.groupBy("a, b")
// must fail. Field c is not a grouping key or aggregation
.select("c");
}
@Test
public void testGroupedAggregate() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple3<Integer, Long, String>> input = CollectionDataSets.get3TupleDataSet(env);
Table table = tableEnv.fromDataSet(input, "a, b, c");
Table result = table
.groupBy("b").select("b, a.sum");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "1,1\n" + "2,5\n" + "3,15\n" + "4,34\n" + "5,65\n" + "6,111\n";
compareResultAsText(results, expected);
}
@Test
public void testGroupingKeyForwardIfNotUsed() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple3<Integer, Long, String>> input = CollectionDataSets.get3TupleDataSet(env);
Table table = tableEnv.fromDataSet(input, "a, b, c");
Table result = table
.groupBy("b").select("a.sum");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "1\n" + "5\n" + "15\n" + "34\n" + "65\n" + "111\n";
compareResultAsText(results, expected);
}
@Test
public void testGroupNoAggregation() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple3<Integer, Long, String>> input = CollectionDataSets.get3TupleDataSet(env);
Table table = tableEnv.fromDataSet(input, "a, b, c");
Table result = table
.groupBy("b").select("a.sum as d, b").groupBy("b, d").select("b");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";
List<Row> results = ds.collect();
compareResultAsText(results, expected);
}
@Test
public void testPojoAggregation() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<WC> input = env.fromElements(
new WC("Hello", 1),
new WC("Ciao", 1),
new WC("Hello", 1),
new WC("Hola", 1),
new WC("Hola", 1));
Table table = tableEnv.fromDataSet(input);
Table filtered = table
.groupBy("word")
.select("word.count as frequency, word")
.filter("frequency = 2");
List<String> result = tableEnv.toDataSet(filtered, WC.class)
.map(new MapFunction<WC, String>() {
public String map(WC value) throws Exception {
return value.word;
}
}).collect();
String expected = "Hello\n" + "Hola";
compareResultAsText(result, expected);
}
@Test
public void testPojoGrouping() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<String, Double, String>> data = env.fromElements(
new Tuple3<>("A", 23.0, "Z"),
new Tuple3<>("A", 24.0, "Y"),
new Tuple3<>("B", 1.0, "Z"));
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
Table table = tableEnv
.fromDataSet(data, "groupMe, value, name")
.select("groupMe, value, name")
.where("groupMe != 'B'");
DataSet<MyPojo> myPojos = tableEnv.toDataSet(table, MyPojo.class);
DataSet<MyPojo> result = myPojos.groupBy("groupMe")
.sortGroup("value", Order.DESCENDING)
.first(1);
List<MyPojo> resultList = result.collect();
compareResultAsText(resultList, "A,24.0,Y");
}
@Test
public void testDistinct() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple3<Integer, Long, String>> input = CollectionDataSets.get3TupleDataSet(env);
Table table = tableEnv.fromDataSet(input, "a, b, c");
Table distinct = table.select("b").distinct();
DataSet<Row> ds = tableEnv.toDataSet(distinct, Row.class);
List<Row> results = ds.collect();
String expected = "1\n" + "2\n" + "3\n"+ "4\n"+ "5\n"+ "6\n";
compareResultAsText(results, expected);
}
@Test
public void testDistinctAfterAggregate() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
DataSet<Tuple5<Integer, Long, Integer, String, Long>> input = CollectionDataSets.get5TupleDataSet(env);
Table table = tableEnv.fromDataSet(input, "a, b, c, d, e");
Table distinct = table.groupBy("a, e").select("e").distinct();
DataSet<Row> ds = tableEnv.toDataSet(distinct, Row.class);
List<Row> results = ds.collect();
String expected = "1\n" + "2\n" + "3\n";
compareResultAsText(results, expected);
}
// --------------------------------------------------------------------------------------------
public static class MyPojo implements Serializable {
private static final long serialVersionUID = 8741918940120107213L;
public String groupMe;
public double value;
public String name;
public MyPojo() {
// for serialization
}
public MyPojo(String groupMe, double value, String name) {
this.groupMe = groupMe;
this.value = value;
this.name = name;
}
@Override
public String toString() {
return groupMe + "," + value + "," + name;
}
}
}