/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.StatsRecorder; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanAssert; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.tpch.TpchConnectorFactory; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.List; import java.util.Map; import java.util.Optional; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.groupingSet; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestMixedDistinctAggregationOptimizer { private LocalQueryRunner queryRunner; @BeforeClass public void setUp() { Session defaultSession = testSessionBuilder() .setCatalog("local") .setSchema(TINY_SCHEMA_NAME) .setSystemProperty(SystemSessionProperties.OPTIMIZE_DISTINCT_AGGREGATIONS, "true") .build(); queryRunner = new LocalQueryRunner(defaultSession); queryRunner.createCatalog(queryRunner.getDefaultSession().getCatalog().get(), new TpchConnectorFactory(1), ImmutableMap.of()); } @AfterClass(alwaysRun = true) public void tearDown() { closeAllRuntimeException(queryRunner); queryRunner = null; } @Test public void testMixedDistinctAggregationOptimizer() { @Language("SQL") String sql = "SELECT custkey, max(totalprice) AS s, count(DISTINCT orderdate) AS d FROM orders GROUP BY custkey"; String group = "GROUP"; // Original keys String groupBy = "CUSTKEY"; String aggregate = "TOTALPRICE"; String distinctAggregation = "ORDERDATE"; // Second Aggregation data List<String> groupByKeysSecond = ImmutableList.of(groupBy); Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregationsSecond = ImmutableMap.of( Optional.of("arbitrary"), PlanMatchPattern.functionCall("arbitrary", false, ImmutableList.of(anySymbol())), Optional.of("count"), PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol()))); // First Aggregation data List<String> groupByKeysFirst = ImmutableList.of(groupBy, distinctAggregation, group); Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregationsFirst = ImmutableMap.of( Optional.of("MAX"), functionCall("max", ImmutableList.of("TOTALPRICE"))); PlanMatchPattern tableScan = tableScan("orders", ImmutableMap.of("TOTALPRICE", "totalprice", "CUSTKEY", "custkey", "ORDERDATE", "orderdate")); // GroupingSet symbols ImmutableList.Builder<List<String>> groups = ImmutableList.builder(); groups.add(ImmutableList.of(groupBy, aggregate)); groups.add(ImmutableList.of(groupBy, distinctAggregation)); PlanMatchPattern expectedPlanPattern = anyTree( aggregation(ImmutableList.of(groupByKeysSecond), aggregationsSecond, ImmutableMap.of(), Optional.empty(), SINGLE, project( aggregation(ImmutableList.of(groupByKeysFirst), aggregationsFirst, ImmutableMap.of(), Optional.empty(), SINGLE, groupingSet(groups.build(), group, anyTree(tableScan)))))); assertUnitPlan(sql, expectedPlanPattern); } @Test public void testNestedType() { // Second Aggregation data Map<String, ExpectedValueProvider<FunctionCall>> aggregationsSecond = ImmutableMap.of( "arbitrary", PlanMatchPattern.functionCall("arbitrary", false, ImmutableList.of(anySymbol())), "count", PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol()))); // First Aggregation data Map<String, ExpectedValueProvider<FunctionCall>> aggregationsFirst = ImmutableMap.of( "max", PlanMatchPattern.functionCall("max", false, ImmutableList.of(anySymbol()))); assertUnitPlan("SELECT count(DISTINCT a), max(b) FROM (VALUES (ROW(1, 2), 3)) t(a, b)", anyTree( aggregation(aggregationsSecond, project( aggregation(aggregationsFirst, anyTree(values(ImmutableMap.of())) ))))); } public void assertUnitPlan(String sql, PlanMatchPattern pattern) { List<PlanOptimizer> optimizers = ImmutableList.of( new UnaliasSymbolReferences(), new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), new OptimizeMixedDistinctAggregations(queryRunner.getMetadata()), new PruneUnreferencedOutputs()); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); return null; }); } }