/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Map; import java.util.Optional; import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.getOnlyElement; /** * Adds INTERMEDIATE aggregations between an un-grouped FINAL aggregation and its preceding * PARTIAL aggregation. * <p> * From: * <pre> * - Aggregation (FINAL) * - RemoteExchange (GATHER) * - Aggregation (PARTIAL) * </pre> * To: * <pre> * - Aggregation (FINAL) * - LocalExchange (GATHER) * - Aggregation (INTERMEDIATE) * - LocalExchange (ARBITRARY) * - RemoteExchange (GATHER) * - Aggregation (INTERMEDIATE) * - LocalExchange (GATHER) * - Aggregation (PARTIAL) * </pre> * <p> */ public class AddIntermediateAggregations implements Rule { @Override public Optional<PlanNode> apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) { if (!SystemSessionProperties.isEnableIntermediateAggregations(session)) { return Optional.empty(); } if (!(node instanceof AggregationNode)) { return Optional.empty(); } AggregationNode aggregation = (AggregationNode) node; // Only consider FINAL un-grouped aggregations if (aggregation.getStep() != AggregationNode.Step.FINAL || !aggregation.getGroupingKeys().isEmpty()) { return Optional.empty(); } Optional<PlanNode> rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator); if (!rewrittenSource.isPresent()) { return Optional.empty(); } PlanNode source = rewrittenSource.get(); if (getTaskConcurrency(session) > 1) { source = ExchangeNode.partitionedExchange( idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source, new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), source.getOutputSymbols())); source = new AggregationNode( idAllocator.getNextId(), source, inputsAsOutputs(aggregation.getAssignments()), aggregation.getGroupingSets(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), aggregation.getGroupIdSymbol()); source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source); } return Optional.of(node.replaceChildren(ImmutableList.of(source))); } /** * Recurse through a series of preceding ExchangeNodes and ProjectNodes to find the preceding PARTIAL aggregation */ private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) { if (node instanceof AggregationNode && ((AggregationNode) node).getStep() == AggregationNode.Step.PARTIAL) { return Optional.of(addGatheringIntermediate((AggregationNode) node, idAllocator)); } if (!(node instanceof ExchangeNode) && !(node instanceof ProjectNode)) { return Optional.empty(); } ImmutableList.Builder<PlanNode> builder = ImmutableList.builder(); for (PlanNode source : node.getSources()) { Optional<PlanNode> planNode = recurseToPartial(lookup.resolve(source), lookup, idAllocator); if (!planNode.isPresent()) { return Optional.empty(); } builder.add(planNode.get()); } return Optional.of(node.replaceChildren(builder.build())); } private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) { verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation"); ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation); return new AggregationNode( idAllocator.getNextId(), gatheringExchange, outputsAsInputs(aggregation.getAssignments()), aggregation.getGroupingSets(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), aggregation.getGroupIdSymbol()); } /** * Rewrite assignments so that inputs are in terms of the output symbols. * <p> * Example: * 'a' := sum('b') => 'a' := sum('a') * 'a' := count(*) => 'a' := count('a') */ private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> assignments) { ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> builder = ImmutableMap.builder(); for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) { Symbol output = entry.getKey(); AggregationNode.Aggregation aggregation = entry.getValue(); builder.put( output, new AggregationNode.Aggregation( new FunctionCall(QualifiedName.of(aggregation.getSignature().getName()), ImmutableList.of(output.toSymbolReference())), aggregation.getSignature(), Optional.empty())); // No mask for INTERMEDIATE } return builder.build(); } /** * Rewrite assignments so that outputs are in terms of the input symbols. * This operation only reliably applies to aggregation steps that take partial inputs (e.g. INTERMEDIATE and split FINALs), * which are guaranteed to have exactly one input and one output. * <p> * Example: * 'a' := sum('b') => 'b' := sum('b') */ private static Map<Symbol, AggregationNode.Aggregation> inputsAsOutputs(Map<Symbol, AggregationNode.Aggregation> assignments) { ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> builder = ImmutableMap.builder(); for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) { // Should only have one input symbol Symbol input = getOnlyElement(DependencyExtractor.extractAll(entry.getValue().getCall())); builder.put(input, entry.getValue()); } return builder.build(); } }