/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.sql.planner.LogicalPlanner; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.Iterables; import org.testng.annotations.Test; import java.util.List; import java.util.Map; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestUnion extends BasePlanTest { public TestUnion() { super(); } public TestUnion(Map<String, String> sessionProperties) { super(sessionProperties); } @Test public void testSimpleUnion() { Plan plan = plan( "SELECT suppkey FROM supplier UNION ALL SELECT nationkey FROM nation", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false); List<PlanNode> remotes = searchFrom(plan.getRoot()) .where(TestUnion::isRemoteExchange) .findAll(); assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange"); assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), GATHER); assertPlanIsFullyDistributed(plan); } @Test public void testUnionOverSingleNodeAggregationAndUnion() { Plan plan = plan( "SELECT count(*) FROM (" + "SELECT 1 FROM nation GROUP BY regionkey " + "UNION ALL (" + " SELECT 1 FROM nation " + " UNION ALL " + " SELECT 1 FROM nation))", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false); List<PlanNode> remotes = searchFrom(plan.getRoot()) .where(TestUnion::isRemoteExchange) .findAll(); assertEquals(remotes.size(), 2, "There should be exactly two RemoteExchanges"); assertEquals(((ExchangeNode) remotes.get(0)).getType(), GATHER); assertEquals(((ExchangeNode) remotes.get(1)).getType(), REPARTITION); } @Test public void testPartialAggregationsWithUnion() { Plan plan = plan( "SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY (orderstatus)", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false); assertAtMostOneAggregationBetweenRemoteExchanges(plan); assertPlanIsFullyDistributed(plan); } @Test public void testPartialRollupAggregationsWithUnion() { Plan plan = plan( "SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY ROLLUP (orderstatus)", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false); assertAtMostOneAggregationBetweenRemoteExchanges(plan); assertPlanIsFullyDistributed(plan); } @Test public void testAggregationWithUnionAndValues() { Plan plan = plan( "SELECT regionkey, count(*) FROM (SELECT regionkey FROM nation UNION ALL SELECT * FROM (VALUES 2, 100) t(regionkey)) GROUP BY regionkey", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false); assertAtMostOneAggregationBetweenRemoteExchanges(plan); // TODO: Enable this check once distributed UNION can handle both partitioned and single node sources at the same time //assertPlanIsFullyDistributed(plan); } @Test public void testUnionOnProbeSide() { Plan plan = plan( "SELECT * FROM (SELECT * FROM nation UNION ALL SELECT * from nation) n, region r WHERE n.regionkey=r.regionkey", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false); assertPlanIsFullyDistributed(plan); } private void assertPlanIsFullyDistributed(Plan plan) { assertTrue( searchFrom(plan.getRoot()) .skipOnlyWhen(TestUnion::isNotRemoteGatheringExchange) .findAll() .stream() .noneMatch(planNode -> planNode instanceof AggregationNode || planNode instanceof JoinNode), "There is an Aggregation or Join between output and first REMOTE GATHER ExchangeNode"); List<PlanNode> gathers = searchFrom(plan.getRoot()) .where(TestUnion::isRemoteGatheringExchange) .findAll() .stream() .collect(toList()); assertEquals(gathers.size(), 1, "Only a single REMOTE GATHER was expected"); } private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) { List<PlanNode> fragments = searchFrom(plan.getRoot()) .where(TestUnion::isRemoteExchange) .findAll() .stream() .flatMap(exchangeNode -> exchangeNode.getSources().stream()) .collect(toList()); for (PlanNode fragment : fragments) { List<PlanNode> aggregations = searchFrom(fragment) .where(AggregationNode.class::isInstance) .skipOnlyWhen(TestUnion::isNotRemoteExchange) .findAll(); assertFalse(aggregations.size() > 1, "More than a single AggregationNode between remote exchanges"); } } private static boolean isNotRemoteGatheringExchange(PlanNode planNode) { return !isRemoteGatheringExchange(planNode); } private static boolean isRemoteGatheringExchange(PlanNode planNode) { return isRemoteExchange(planNode) && ((ExchangeNode) planNode).getType().equals(GATHER); } private static boolean isNotRemoteExchange(PlanNode planNode) { return !isRemoteExchange(planNode); } private static boolean isRemoteExchange(PlanNode planNode) { return (planNode instanceof ExchangeNode) && ((ExchangeNode) planNode).getScope().equals(REMOTE); } }