/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.fn.impl;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.BaseTestQuery;
import org.apache.drill.PlanTestBase;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.common.util.TestTools;
import org.apache.drill.exec.proto.UserBitShared;
import org.apache.drill.exec.rpc.user.QueryDataBatch;
import org.junit.Ignore;
import org.junit.Test;
import java.util.List;
import java.util.Map;
/**
 * Tests for Drill aggregate functions (COUNT, SUM, AVG, MIN, MAX, STDDEV_*)
 * covering nullable and empty inputs, complex (map/list) inputs, group-by on
 * expressions and system functions, and plan-level verification of filter
 * pushdown around aggregates.
 */
public class TestAggregateFunctions extends BaseTestQuery {

  private static final String TEST_RES_PATH = TestTools.getWorkingPath() + "/src/test/resources";

  /*
   * Test checks the count of a nullable column within a map
   * and verifies count is equal only to the number of times the
   * column appears and doesn't include the null count
   */
  @Test
  public void testCountOnNullableColumn() throws Exception {
    testBuilder()
        .sqlQuery("select count(t.x.y) as cnt1, count(`integer`) as cnt2 from cp.`/jsoninput/input2.json` t")
        .ordered()
        .baselineColumns("cnt1", "cnt2")
        .baselineValues(3L, 4L)
        .build().run();
  }

  @Test
  public void testCountDistinctOnBoolColumn() throws Exception {
    // boolean has exactly two distinct non-null values
    testBuilder()
        .sqlQuery("select count(distinct `bool_val`) as cnt from `sys`.`options`")
        .ordered()
        .baselineColumns("cnt")
        .baselineValues(2L)
        .build().run();
  }

  @Test
  public void testMaxWithZeroInput() throws Exception {
    // MAX over an all-zero expression should return 0.0, not null
    testBuilder()
        .sqlQuery("select max(employee_id * 0.0) as max_val from cp.`employee.json`")
        .unOrdered()
        .baselineColumns("max_val")
        .baselineValues(0.0d)
        .go();
  }

  @Ignore
  @Test // DRILL-2092: count distinct, non distinct aggregate with group-by
  public void testDrill2092() throws Exception {
    String query = "select a1, b1, count(distinct c1) as dist1, \n"
        + "sum(c1) as sum1, count(c1) as cnt1, count(*) as cnt \n"
        + "from cp.`agg/bugs/drill2092/input.json` \n"
        + "group by a1, b1 order by a1, b1";

    String baselineQuery =
        "select case when columns[0]='null' then cast(null as bigint) else cast(columns[0] as bigint) end as a1, \n"
        + "case when columns[1]='null' then cast(null as bigint) else cast(columns[1] as bigint) end as b1, \n"
        + "case when columns[2]='null' then cast(null as bigint) else cast(columns[2] as bigint) end as dist1, \n"
        + "case when columns[3]='null' then cast(null as bigint) else cast(columns[3] as bigint) end as sum1, \n"
        + "case when columns[4]='null' then cast(null as bigint) else cast(columns[4] as bigint) end as cnt1, \n"
        + "case when columns[5]='null' then cast(null as bigint) else cast(columns[5] as bigint) end as cnt \n"
        + "from cp.`agg/bugs/drill2092/result.tsv`";

    // NOTE: this type of query gets rewritten by Calcite into an inner join of subqueries, so
    // we need to test with both hash join and merge join
    testBuilder()
        .sqlQuery(query)
        .ordered()
        .optionSettingQueriesForTestQuery("alter system set `planner.enable_hashjoin` = true")
        .sqlBaselineQuery(baselineQuery)
        .build().run();

    testBuilder()
        .sqlQuery(query)
        .ordered()
        .optionSettingQueriesForTestQuery("alter system set `planner.enable_hashjoin` = false")
        .sqlBaselineQuery(baselineQuery)
        .build().run();
  }

  @Test // DRILL-2170: Subquery has group-by, order-by on aggregate function and limit
  public void testDrill2170() throws Exception {
    String query =
        "select count(*) as cnt from "
        + "cp.`tpch/orders.parquet` o inner join\n"
        + "(select l_orderkey, sum(l_quantity), sum(l_extendedprice) \n"
        + "from cp.`tpch/lineitem.parquet` \n"
        + "group by l_orderkey order by 3 limit 100) sq \n"
        + "on sq.l_orderkey = o.o_orderkey";

    testBuilder()
        .sqlQuery(query)
        .ordered()
        .optionSettingQueriesForTestQuery("alter system set `planner.slice_target` = 1000")
        .baselineColumns("cnt")
        .baselineValues(100L)
        .build().run();
  }

  @Test // DRILL-2168
  public void testGBExprWithDrillFunc() throws Exception {
    // group-by / having / order-by all reference the same Drill function expression
    testBuilder()
        .ordered()
        .sqlQuery("select concat(n_name, cast(n_nationkey as varchar(10))) as name, count(*) as cnt " +
            "from cp.`tpch/nation.parquet` " +
            "group by concat(n_name, cast(n_nationkey as varchar(10))) " +
            "having concat(n_name, cast(n_nationkey as varchar(10))) > 'UNITED'" +
            "order by concat(n_name, cast(n_nationkey as varchar(10)))")
        .baselineColumns("name", "cnt")
        .baselineValues("UNITED KINGDOM23", 1L)
        .baselineValues("UNITED STATES24", 1L)
        .baselineValues("VIETNAM21", 1L)
        .build().run();
  }

  @Test //DRILL-2242
  public void testDRILLNestedGBWithSubsetKeys() throws Exception {
    String sql = " select count(*) as cnt from (select l_partkey from\n" +
        "   (select l_partkey, l_suppkey from cp.`tpch/lineitem.parquet`\n" +
        "      group by l_partkey, l_suppkey) \n" +
        "   group by l_partkey )";

    // exercise both single-phase and multi-phase aggregation plans
    test("alter session set `planner.slice_target` = 1; alter session set `planner.enable_multiphase_agg` = false ;");
    testBuilder()
        .ordered()
        .sqlQuery(sql)
        .baselineColumns("cnt")
        .baselineValues(2000L)
        .build().run();

    test("alter session set `planner.slice_target` = 1; alter session set `planner.enable_multiphase_agg` = true ;");
    testBuilder()
        .ordered()
        .sqlQuery(sql)
        .baselineColumns("cnt")
        .baselineValues(2000L)
        .build().run();

    // restore the default-ish slice target for subsequent tests
    test("alter session set `planner.slice_target` = 100000");
  }

  @Test
  public void testAvgWithNullableScalarFunction() throws Exception {
    String query = " select avg(length(b1)) as col from cp.`jsoninput/nullable1.json`";
    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("col")
        .baselineValues(3.0d)
        .go();
  }

  @Test
  public void testCountWithAvg() throws Exception {
    testBuilder()
        .sqlQuery("select count(a) col1, avg(b) col2 from cp.`jsoninput/nullable3.json`")
        .unOrdered()
        .baselineColumns("col1", "col2")
        .baselineValues(2L, 3.0d)
        .go();

    testBuilder()
        .sqlQuery("select count(a) col1, avg(a) col2 from cp.`jsoninput/nullable3.json`")
        .unOrdered()
        .baselineColumns("col1", "col2")
        .baselineValues(2L, 1.0d)
        .go();
  }

  @Test
  public void testAvgOnKnownType() throws Exception {
    testBuilder()
        .sqlQuery("select avg(cast(employee_id as bigint)) as col from cp.`employee.json`")
        .unOrdered()
        .baselineColumns("col")
        .baselineValues(578.9982683982684d)
        .go();
  }

  @Test
  public void testStddevOnKnownType() throws Exception {
    testBuilder()
        .sqlQuery("select stddev_samp(cast(employee_id as int)) as col from cp.`employee.json`")
        .unOrdered()
        .baselineColumns("col")
        .baselineValues(333.56708470261117d)
        .go();
  }

  @Test
  // test aggregates when input is empty and data type is optional
  public void countEmptyNullableInput() throws Exception {
    String query = "select " +
        "count(employee_id) col1, avg(employee_id) col2, sum(employee_id) col3 " +
        "from cp.`employee.json` where 1 = 0";

    // COUNT over empty input is 0; AVG/SUM are null
    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("col1", "col2", "col3")
        .baselineValues(0L, null, null)
        .go();
  }

  @Test
  @Ignore("DRILL-4473")
  public void sumEmptyNonexistentNullableInput() throws Exception {
    final String query = "select "
        + "sum(int_col) col1, sum(bigint_col) col2, sum(float4_col) col3, sum(float8_col) col4, sum(interval_year_col) col5 "
        + "from cp.`employee.json` where 1 = 0";

    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("col1", "col2", "col3", "col4", "col5")
        .baselineValues(null, null, null, null, null)
        .go();
  }

  @Test
  @Ignore("DRILL-4473")
  public void avgEmptyNonexistentNullableInput() throws Exception {
    // test avg function
    final String query = "select "
        + "avg(int_col) col1, avg(bigint_col) col2, avg(float4_col) col3, avg(float8_col) col4, avg(interval_year_col) col5 "
        + "from cp.`employee.json` where 1 = 0";

    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("col1", "col2", "col3", "col4", "col5")
        .baselineValues(null, null, null, null, null)
        .go();
  }

  @Test
  public void stddevEmptyNonexistentNullableInput() throws Exception {
    // test stddev function
    final String query = "select " +
        "stddev_pop(int_col) col1, stddev_pop(bigint_col) col2, stddev_pop(float4_col) col3, " +
        "stddev_pop(float8_col) col4, stddev_pop(interval_year_col) col5 " +
        "from cp.`employee.json` where 1 = 0";

    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("col1", "col2", "col3", "col4", "col5")
        .baselineValues(null, null, null, null, null)
        .go();
  }

  @Test
  public void minMaxEmptyNonNullableInput() throws Exception {
    // test min and max functions on required type: probe the file's schema with
    // a limit-0 query, then run min/max over every column against empty input,
    // expecting null for each result.
    final QueryDataBatch result = testSqlWithResults("select * from cp.`parquet/alltypes_required.parquet` limit 0")
        .get(0);

    final Map<String, StringBuilder> functions = Maps.newHashMap();
    functions.put("min", new StringBuilder());
    functions.put("max", new StringBuilder());

    final Map<String, Object> resultingValues = Maps.newHashMap();
    for (UserBitShared.SerializedField field : result.getHeader().getDef().getFieldList()) {
      final String fieldName = field.getNamePart().getName();
      // Only COUNT aggregate function supported for Boolean type
      if (fieldName.equals("col_bln")) {
        continue;
      }
      resultingValues.put(String.format("`%s`", fieldName), null);
      for (Map.Entry<String, StringBuilder> function : functions.entrySet()) {
        function.getValue()
            .append(function.getKey())
            .append("(")
            .append(fieldName)
            .append(") ")
            .append(fieldName)
            .append(",");
      }
    }
    result.release();

    final String query = "select %s from cp.`parquet/alltypes_required.parquet` where 1 = 0";
    final List<Map<String, Object>> baselineRecords = Lists.newArrayList();
    baselineRecords.add(resultingValues);

    for (StringBuilder selectBody : functions.values()) {
      // drop the trailing comma from the generated select list
      selectBody.setLength(selectBody.length() - 1);

      testBuilder()
          .sqlQuery(query, selectBody.toString())
          .unOrdered()
          .baselineRecords(baselineRecords)
          .go();
    }
  }

  /*
   * Streaming agg on top of a filter produces wrong results if the first two batches are filtered out.
   * In the below test we have three files in the input directory and since the ordering of reading
   * of these files may not be deterministic, we have three tests to make sure we test the case where
   * streaming agg gets two empty batches.
   */
  @Test
  public void drill3069() throws Exception {
    final String query = "select max(foo) col1 from dfs_test.`%s/agg/bugs/drill3069` where foo = %d";
    testBuilder()
        .sqlQuery(String.format(query, TEST_RES_PATH, 2))
        .unOrdered()
        .baselineColumns("col1")
        .baselineValues(2L)
        .go();

    testBuilder()
        .sqlQuery(String.format(query, TEST_RES_PATH, 4))
        .unOrdered()
        .baselineColumns("col1")
        .baselineValues(4L)
        .go();

    testBuilder()
        .sqlQuery(String.format(query, TEST_RES_PATH, 6))
        .unOrdered()
        .baselineColumns("col1")
        .baselineValues(6L)
        .go();
  }

  @Test //DRILL-2748
  public void testPushFilterPastAgg() throws Exception {
    final String query =
        " select cnt " +
        " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) " +
        " where n_regionkey = 2 ";

    // Validate the plan: the filter on the group key must be pushed below the aggregate
    final String[] expectedPlan = {"(?s)(StreamAgg|HashAgg).*Filter"};
    final String[] excludedPatterns = {"(?s)Filter.*(StreamAgg|HashAgg)"};
    PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPatterns);

    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("cnt")
        .baselineValues(5L)
        .build().run();

    // having clause
    final String query2 =
        " select count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey " +
        " having n_regionkey = 2 ";
    PlanTestBase.testPlanMatchingPatterns(query2, expectedPlan, excludedPatterns);

    // BUGFIX: previously re-ran `query` here; run the HAVING variant that was just plan-checked
    testBuilder()
        .sqlQuery(query2)
        .unOrdered()
        .baselineColumns("cnt")
        .baselineValues(5L)
        .build().run();
  }

  @Test
  public void testPushFilterInExprPastAgg() throws Exception {
    final String query =
        " select cnt " +
        " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) " +
        " where n_regionkey + 100 - 100 = 2 ";

    // Validate the plan: an expression over the group key alone is still pushable
    final String[] expectedPlan = {"(?s)(StreamAgg|HashAgg).*Filter"};
    final String[] excludedPatterns = {"(?s)Filter.*(StreamAgg|HashAgg)"};
    PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPatterns);

    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("cnt")
        .baselineValues(5L)
        .build().run();
  }

  @Test
  public void testNegPushFilterInExprPastAgg() throws Exception {
    // negative case: should not push filter, since it involves the aggregate result
    final String query =
        " select cnt " +
        " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) " +
        " where cnt + 100 - 100 = 5 ";

    // Validate the plan
    final String[] expectedPlan = {"(?s)Filter(?!StreamAgg|!HashAgg)"};
    final String[] excludedPatterns = {"(?s)(StreamAgg|HashAgg).*Filter"};
    PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPatterns);

    // negative case: should not push filter, since it is expression of group key + agg result.
    final String query2 =
        " select cnt " +
        " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) " +
        " where cnt + n_regionkey = 5 ";
    PlanTestBase.testPlanMatchingPatterns(query2, expectedPlan, excludedPatterns);
  }

  @Test // DRILL-3781
  // GROUP BY System functions in schema table.
  public void testGroupBySystemFuncSchemaTable() throws Exception {
    final String query = "select count(*) as cnt from sys.version group by CURRENT_DATE";
    final String[] expectedPlan = {"(?s)(StreamAgg|HashAgg)"};
    final String[] excludedPatterns = {};

    PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPatterns);
  }

  @Test //DRILL-3781
  // GROUP BY System functions in csv, parquet, json table.
  public void testGroupBySystemFuncFileSystemTable() throws Exception {
    final String query = String.format("select count(*) as cnt from dfs_test.`%s/nation/nation.tbl` group by CURRENT_DATE", TEST_RES_PATH);
    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("cnt")
        .baselineValues(25L)
        .build().run();

    final String query2 = "select count(*) as cnt from cp.`tpch/nation.parquet` group by CURRENT_DATE";
    testBuilder()
        .sqlQuery(query2)
        .unOrdered()
        .baselineColumns("cnt")
        .baselineValues(25L)
        .build().run();

    final String query3 = "select count(*) as cnt from cp.`employee.json` group by CURRENT_DATE";
    testBuilder()
        .sqlQuery(query3)
        .unOrdered()
        .baselineColumns("cnt")
        .baselineValues(1155L)
        .build().run();
  }

  @Test
  public void test4443() throws Exception {
    // DRILL-4443: MIN over a columns[] reference with group by on another element
    test("SELECT MIN(columns[1]) FROM dfs_test.`%s/agg/4443.csv` GROUP BY columns[0]", TEST_RES_PATH);
  }

  @Test
  public void testCountStarRequired() throws Exception {
    final String query = "select count(*) as col from cp.`tpch/region.parquet`";

    // count(*) must produce a REQUIRED (non-nullable) BIGINT column
    List<Pair<SchemaPath, TypeProtos.MajorType>> expectedSchema = Lists.newArrayList();
    TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder()
        .setMinorType(TypeProtos.MinorType.BIGINT)
        .setMode(TypeProtos.DataMode.REQUIRED)
        .build();
    expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType));

    testBuilder()
        .sqlQuery(query)
        .schemaBaseLine(expectedSchema)
        .build()
        .run();

    testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns("col")
        .baselineValues(5L)
        .build()
        .run();
  }

  @Test // DRILL-4531
  public void testPushFilterDown() throws Exception {
    final String sql =
        "SELECT  cust.custAddress, \n"
        + "       lineitem.provider \n"
        + "FROM ( \n"
        + "      SELECT cast(c_custkey AS bigint) AS custkey, \n"
        + "             c_address                 AS custAddress \n"
        + "      FROM   cp.`tpch/customer.parquet` ) cust \n"
        + "LEFT JOIN \n"
        + "  ( \n"
        + "    SELECT DISTINCT l_linenumber, \n"
        + "           CASE \n"
        + "             WHEN l_partkey IN (1, 2) THEN 'Store1'\n"
        + "             WHEN l_partkey IN (5, 6) THEN 'Store2'\n"
        + "           END AS provider \n"
        + "    FROM  cp.`tpch/lineitem.parquet` \n"
        + "    WHERE ( l_orderkey >=20160101 AND l_partkey <=20160301) \n"
        + "          AND l_partkey IN (1,2, 5, 6) ) lineitem\n"
        + "ON        cust.custkey = lineitem.l_linenumber \n"
        + "WHERE     provider IS NOT NULL \n"
        + "GROUP BY  cust.custAddress, \n"
        + "          lineitem.provider \n"
        + "ORDER BY  cust.custAddress, \n"
        + "          lineitem.provider";

    // Validate the plan
    final String[] expectedPlan = {"(?s)(Join).*inner"}; // With filter pushdown, left join will be converted into inner join
    final String[] excludedPatterns = {"(?s)(Join).*(left)"};
    PlanTestBase.testPlanMatchingPatterns(sql, expectedPlan, excludedPatterns);
  }

  @Test // DRILL-2385: count on complex objects failed with missing function implementation
  public void testCountComplexObjects() throws Exception {
    final String query = "select count(t.%s) %s from cp.`complex/json/complex.json` t";
    // alias -> column path in the complex JSON file
    Map<String, String> objectsMap = Maps.newHashMap();
    objectsMap.put("COUNT_BIG_INT_REPEATED", "sia");
    objectsMap.put("COUNT_FLOAT_REPEATED", "sfa");
    objectsMap.put("COUNT_MAP_REPEATED", "soa");
    objectsMap.put("COUNT_MAP_REQUIRED", "oooi");
    objectsMap.put("COUNT_LIST_REPEATED", "odd");
    objectsMap.put("COUNT_LIST_OPTIONAL", "sia");

    for (Map.Entry<String, String> entry : objectsMap.entrySet()) {
      final String alias = entry.getKey();
      String optionSetting = "";
      if (alias.equals("COUNT_LIST_OPTIONAL")) {
        // if `exec.enable_union_type` parameter is true then BIGINT<REPEATED> object is converted to LIST<OPTIONAL> one
        optionSetting = "alter session set `exec.enable_union_type`=true";
      }
      try {
        testBuilder()
            .sqlQuery(query, entry.getValue(), alias)
            .optionSettingQueriesForTestQuery(optionSetting)
            .unOrdered()
            .baselineColumns(alias)
            .baselineValues(3L)
            .go();
      } finally {
        // always restore the option so a failure doesn't leak state into other tests
        test("ALTER SESSION RESET `exec.enable_union_type`");
      }
    }
  }
}